diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/aot/AnthropicRuntimeHints.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/aot/AnthropicRuntimeHints.java index 1f2da011942..3a455404a5b 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/aot/AnthropicRuntimeHints.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/aot/AnthropicRuntimeHints.java @@ -16,6 +16,7 @@ package org.springframework.ai.anthropic.aot; +import org.springframework.ai.anthropic.AnthropicChatOptions; import org.springframework.ai.anthropic.api.AnthropicApi; import org.springframework.aot.hint.MemberCategory; import org.springframework.aot.hint.RuntimeHints; @@ -37,7 +38,8 @@ public class AnthropicRuntimeHints implements RuntimeHintsRegistrar { @Override public void registerHints(@NonNull RuntimeHints hints, @Nullable ClassLoader classLoader) { var mcs = MemberCategory.values(); - for (var tr : findJsonAnnotatedClassesInPackage(AnthropicApi.class)) { + + for (var tr : findJsonAnnotatedClassesInPackage("org.springframework.ai.anthropic")) { hints.reflection().registerType(tr, mcs); } } diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/aot/AnthropicRuntimeHintsTests.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/aot/AnthropicRuntimeHintsTests.java index f38a6c8e671..4ecffac59d0 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/aot/AnthropicRuntimeHintsTests.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/aot/AnthropicRuntimeHintsTests.java @@ -16,6 +16,7 @@ package org.springframework.ai.anthropic.aot; +import java.util.HashSet; import java.util.Set; import org.junit.jupiter.api.Test; @@ -26,7 +27,6 @@ import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage; -import static org.springframework.aot.hint.predicate.RuntimeHintsPredicates.reflection; class AnthropicRuntimeHintsTests { @@ -36,10 +36,23 @@ void registerHints() { AnthropicRuntimeHints anthropicRuntimeHints = new AnthropicRuntimeHints(); anthropicRuntimeHints.registerHints(runtimeHints, null); - Set jsonAnnotatedClasses = findJsonAnnotatedClassesInPackage(AnthropicApi.class); + Set jsonAnnotatedClasses = findJsonAnnotatedClassesInPackage("org.springframework.ai.anthropic"); + + Set registeredTypes = new HashSet<>(); + runtimeHints.reflection().typeHints().forEach(typeHint -> registeredTypes.add(typeHint.getType())); + for (TypeReference jsonAnnotatedClass : jsonAnnotatedClasses) { - assertThat(runtimeHints).matches(reflection().onType(jsonAnnotatedClass)); + assertThat(registeredTypes.contains(jsonAnnotatedClass)).isTrue(); } + + // Check a few more specific ones + assertThat(registeredTypes.contains(TypeReference.of(AnthropicApi.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(AnthropicApi.Role.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(AnthropicApi.ThinkingType.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(AnthropicApi.EventType.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(AnthropicApi.ContentBlock.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(AnthropicApi.ChatCompletionRequest.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(AnthropicApi.AnthropicMessage.class))).isTrue(); } } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHints.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHints.java index 0bec3c1c9f3..79939c62651 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHints.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHints.java @@ -41,21 +41,8 @@ public class BedrockRuntimeHints implements RuntimeHintsRegistrar { @Override public void registerHints(RuntimeHints hints, ClassLoader classLoader) { var mcs = MemberCategory.values(); - for (var tr : findJsonAnnotatedClassesInPackage(AbstractBedrockApi.class)) { - hints.reflection().registerType(tr, mcs); - } - for (var tr : findJsonAnnotatedClassesInPackage(CohereEmbeddingBedrockApi.class)) { - hints.reflection().registerType(tr, mcs); - } - for (var tr : findJsonAnnotatedClassesInPackage(BedrockCohereEmbeddingOptions.class)) { - hints.reflection().registerType(tr, mcs); - } - - for (var tr : findJsonAnnotatedClassesInPackage(BedrockTitanEmbeddingOptions.class)) { - hints.reflection().registerType(tr, mcs); - } - for (var tr : findJsonAnnotatedClassesInPackage(TitanEmbeddingBedrockApi.class)) { + for (var tr : findJsonAnnotatedClassesInPackage("org.springframework.ai.bedrock")) { hints.reflection().registerType(tr, mcs); } } diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHintsTests.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHintsTests.java index 3a87d56dd15..1024647f9a5 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHintsTests.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHintsTests.java @@ -16,20 +16,21 @@ package org.springframework.ai.bedrock.aot; -import java.util.Arrays; -import java.util.List; +import java.util.HashSet; import java.util.Set; import org.junit.jupiter.api.Test; +import org.springframework.ai.bedrock.api.AbstractBedrockApi; +import org.springframework.ai.bedrock.cohere.BedrockCohereEmbeddingOptions; import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi; +import org.springframework.ai.bedrock.titan.BedrockTitanEmbeddingOptions; import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi; import org.springframework.aot.hint.RuntimeHints; import org.springframework.aot.hint.TypeReference; import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage; -import static org.springframework.aot.hint.predicate.RuntimeHintsPredicates.reflection; class BedrockRuntimeHintsTests { @@ -39,15 +40,22 @@ void registerHints() { BedrockRuntimeHints bedrockRuntimeHints = new BedrockRuntimeHints(); bedrockRuntimeHints.registerHints(runtimeHints, null); - List classList = Arrays.asList(CohereEmbeddingBedrockApi.class, TitanEmbeddingBedrockApi.class); + Set jsonAnnotatedClasses = findJsonAnnotatedClassesInPackage("org.springframework.ai.bedrock"); - for (Class aClass : classList) { - Set jsonAnnotatedClasses = findJsonAnnotatedClassesInPackage(aClass); - for (TypeReference jsonAnnotatedClass : jsonAnnotatedClasses) { - assertThat(runtimeHints).matches(reflection().onType(jsonAnnotatedClass)); - } + Set registeredTypes = new HashSet<>(); + runtimeHints.reflection().typeHints().forEach(typeHint -> registeredTypes.add(typeHint.getType())); + + for (TypeReference jsonAnnotatedClass : jsonAnnotatedClasses) { + assertThat(registeredTypes.contains(jsonAnnotatedClass)).isTrue(); } + // Check a few more specific ones + assertThat(registeredTypes.contains(TypeReference.of(AbstractBedrockApi.AmazonBedrockInvocationMetrics.class))) + .isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(CohereEmbeddingBedrockApi.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(BedrockCohereEmbeddingOptions.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(BedrockTitanEmbeddingOptions.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(TitanEmbeddingBedrockApi.class))).isTrue(); } } diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/aot/MiniMaxRuntimeHints.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/aot/MiniMaxRuntimeHints.java index 5a370bb6816..263edcdf0a4 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/aot/MiniMaxRuntimeHints.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/aot/MiniMaxRuntimeHints.java @@ -37,7 +37,8 @@ public class MiniMaxRuntimeHints implements RuntimeHintsRegistrar { @Override public void registerHints(@NonNull RuntimeHints hints, @Nullable ClassLoader classLoader) { var mcs = MemberCategory.values(); - for (var tr : findJsonAnnotatedClassesInPackage(MiniMaxApi.class)) { + + for (var tr : findJsonAnnotatedClassesInPackage("org.springframework.ai.minimax")) { hints.reflection().registerType(tr, mcs); } } diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/aot/MistralAiRuntimeHints.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/aot/MistralAiRuntimeHints.java index 3bb55073807..30967727e02 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/aot/MistralAiRuntimeHints.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/aot/MistralAiRuntimeHints.java @@ -16,13 +16,14 @@ package org.springframework.ai.mistralai.aot; -import org.springframework.ai.mistralai.api.MistralAiApi; import org.springframework.aot.hint.MemberCategory; import org.springframework.aot.hint.RuntimeHints; import org.springframework.aot.hint.RuntimeHintsRegistrar; import org.springframework.lang.NonNull; import org.springframework.lang.Nullable; +import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage; + /** * The MistralAiRuntimeHints class is responsible for registering runtime hints for * Mistral AI API classes. @@ -35,7 +36,8 @@ public class MistralAiRuntimeHints implements RuntimeHintsRegistrar { @Override public void registerHints(@NonNull RuntimeHints hints, @Nullable ClassLoader classLoader) { var mcs = MemberCategory.values(); - for (var tr : org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage(MistralAiApi.class)) { + + for (var tr : findJsonAnnotatedClassesInPackage("org.springframework.ai.mistralai")) { hints.reflection().registerType(tr, mcs); } } diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/aot/MistralAiRuntimeHintsTests.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/aot/MistralAiRuntimeHintsTests.java index ac20dbccb96..2ce0bcc56f1 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/aot/MistralAiRuntimeHintsTests.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/aot/MistralAiRuntimeHintsTests.java @@ -16,15 +16,19 @@ package org.springframework.ai.mistralai.aot; +import java.util.HashSet; import java.util.Set; import org.junit.jupiter.api.Test; +import org.springframework.ai.mistralai.MistralAiChatOptions; +import org.springframework.ai.mistralai.MistralAiEmbeddingOptions; import org.springframework.ai.mistralai.api.MistralAiApi; import org.springframework.aot.hint.RuntimeHints; import org.springframework.aot.hint.TypeReference; import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage; import static org.springframework.aot.hint.predicate.RuntimeHintsPredicates.reflection; class MistralAiRuntimeHintsTests { @@ -35,11 +39,22 @@ void registerHints() { MistralAiRuntimeHints mistralAiRuntimeHints = new MistralAiRuntimeHints(); mistralAiRuntimeHints.registerHints(runtimeHints, null); - Set jsonAnnotatedClasses = org.springframework.ai.aot.AiRuntimeHints - .findJsonAnnotatedClassesInPackage(MistralAiApi.class); + Set jsonAnnotatedClasses = findJsonAnnotatedClassesInPackage("org.springframework.ai.mistralai"); + + Set registeredTypes = new HashSet<>(); + runtimeHints.reflection().typeHints().forEach(typeHint -> registeredTypes.add(typeHint.getType())); + for (TypeReference jsonAnnotatedClass : jsonAnnotatedClasses) { - assertThat(runtimeHints).matches(reflection().onType(jsonAnnotatedClass)); + assertThat(registeredTypes.contains(jsonAnnotatedClass)).isTrue(); } + + // Check a few more specific ones + assertThat(registeredTypes.contains(TypeReference.of(MistralAiApi.ChatCompletion.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(MistralAiApi.ChatCompletionChunk.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(MistralAiApi.LogProbs.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(MistralAiApi.ChatCompletionFinishReason.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(MistralAiChatOptions.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(MistralAiEmbeddingOptions.class))).isTrue(); } } diff --git a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/aot/MoonshotRuntimeHints.java b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/aot/MoonshotRuntimeHints.java index 58c32304705..c87478e464a 100644 --- a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/aot/MoonshotRuntimeHints.java +++ b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/aot/MoonshotRuntimeHints.java @@ -36,7 +36,7 @@ public class MoonshotRuntimeHints implements RuntimeHintsRegistrar { @Override public void registerHints(@NonNull RuntimeHints hints, @Nullable ClassLoader classLoader) { var mcs = MemberCategory.values(); - for (var tr : findJsonAnnotatedClassesInPackage(MoonshotApi.class)) { + for (var tr : findJsonAnnotatedClassesInPackage("org.springframework.ai.moonshot")) { hints.reflection().registerType(tr, mcs); } } diff --git a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/aot/MoonshotRuntimeHintsTests.java b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/aot/MoonshotRuntimeHintsTests.java index 60bb11f0828..9411c397ae9 100644 --- a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/aot/MoonshotRuntimeHintsTests.java +++ b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/aot/MoonshotRuntimeHintsTests.java @@ -16,10 +16,12 @@ package org.springframework.ai.moonshot.aot; +import java.util.HashSet; import java.util.Set; import org.junit.jupiter.api.Test; +import org.springframework.ai.moonshot.MoonshotChatOptions; import org.springframework.ai.moonshot.api.MoonshotApi; import org.springframework.aot.hint.RuntimeHints; import org.springframework.aot.hint.TypeReference; @@ -39,10 +41,21 @@ void registerHints() { MoonshotRuntimeHints moonshotRuntimeHints = new MoonshotRuntimeHints(); moonshotRuntimeHints.registerHints(runtimeHints, null); - Set jsonAnnotatedClasses = findJsonAnnotatedClassesInPackage(MoonshotApi.class); + Set jsonAnnotatedClasses = findJsonAnnotatedClassesInPackage("org.springframework.ai.moonshot"); + + Set registeredTypes = new HashSet<>(); + runtimeHints.reflection().typeHints().forEach(typeHint -> registeredTypes.add(typeHint.getType())); + for (TypeReference jsonAnnotatedClass : jsonAnnotatedClasses) { - assertThat(runtimeHints).matches(reflection().onType(jsonAnnotatedClass)); + assertThat(registeredTypes.contains(jsonAnnotatedClass)).isTrue(); } + + // Check a few more specific ones + assertThat(registeredTypes.contains(TypeReference.of(MoonshotApi.ChatCompletion.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(MoonshotApi.ChatCompletionRequest.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(MoonshotApi.ChatCompletionChunk.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(MoonshotApi.Usage.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(MoonshotChatOptions.class))).isTrue(); } } diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/aot/OllamaRuntimeHints.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/aot/OllamaRuntimeHints.java index 057df2376a2..e33cffc3be5 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/aot/OllamaRuntimeHints.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/aot/OllamaRuntimeHints.java @@ -37,10 +37,7 @@ public class OllamaRuntimeHints implements RuntimeHintsRegistrar { @Override public void registerHints(RuntimeHints hints, ClassLoader classLoader) { var mcs = MemberCategory.values(); - for (var tr : findJsonAnnotatedClassesInPackage(OllamaApi.class)) { - hints.reflection().registerType(tr, mcs); - } - for (var tr : findJsonAnnotatedClassesInPackage(OllamaOptions.class)) { + for (var tr : findJsonAnnotatedClassesInPackage("org.springframework.ai.ollama")) { hints.reflection().registerType(tr, mcs); } } diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/aot/OllamaRuntimeHintsTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/aot/OllamaRuntimeHintsTests.java index 3b030e8c6f6..8e8a03c1fae 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/aot/OllamaRuntimeHintsTests.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/aot/OllamaRuntimeHintsTests.java @@ -16,6 +16,7 @@ package org.springframework.ai.ollama.aot; +import java.util.HashSet; import java.util.Set; import org.junit.jupiter.api.Test; @@ -37,15 +38,20 @@ void registerHints() { OllamaRuntimeHints ollamaRuntimeHints = new OllamaRuntimeHints(); ollamaRuntimeHints.registerHints(runtimeHints, null); - Set jsonAnnotatedClasses = findJsonAnnotatedClassesInPackage(OllamaApi.class); - for (TypeReference jsonAnnotatedClass : jsonAnnotatedClasses) { - assertThat(runtimeHints).matches(reflection().onType(jsonAnnotatedClass)); - } + Set jsonAnnotatedClasses = findJsonAnnotatedClassesInPackage("org.springframework.ai.ollama"); + + Set registeredTypes = new HashSet<>(); + runtimeHints.reflection().typeHints().forEach(typeHint -> registeredTypes.add(typeHint.getType())); - jsonAnnotatedClasses = findJsonAnnotatedClassesInPackage(OllamaOptions.class); for (TypeReference jsonAnnotatedClass : jsonAnnotatedClasses) { - assertThat(runtimeHints).matches(reflection().onType(jsonAnnotatedClass)); + assertThat(registeredTypes.contains(jsonAnnotatedClass)).isTrue(); } + + // Check a few more specific ones + assertThat(registeredTypes.contains(TypeReference.of(OllamaApi.ChatRequest.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(OllamaApi.ChatRequest.Tool.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(OllamaApi.Message.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(OllamaOptions.class))).isTrue(); } } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/aot/OpenAiRuntimeHints.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/aot/OpenAiRuntimeHints.java index 3aafbf164a0..5e7cbd54921 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/aot/OpenAiRuntimeHints.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/aot/OpenAiRuntimeHints.java @@ -41,19 +41,10 @@ public class OpenAiRuntimeHints implements RuntimeHintsRegistrar { @Override public void registerHints(@NonNull RuntimeHints hints, @Nullable ClassLoader classLoader) { var mcs = MemberCategory.values(); - for (var tr : (findJsonAnnotatedClassesInPackage(OpenAiChatOptions.class))) { - hints.reflection().registerType(tr, mcs); - } - for (var tr : (findJsonAnnotatedClassesInPackage(OpenAiApi.class))) { - hints.reflection().registerType(tr, mcs); - } - for (var tr : (findJsonAnnotatedClassesInPackage(OpenAiAudioApi.class))) { - hints.reflection().registerType(tr, mcs); - } - for (var tr : findJsonAnnotatedClassesInPackage(OpenAiImageApi.class)) { + + for (var tr : findJsonAnnotatedClassesInPackage("org.springframework.ai.openai")) { hints.reflection().registerType(tr, mcs); } - } } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/aot/OpenAiRuntimeHintsTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/aot/OpenAiRuntimeHintsTests.java index e3d849898df..5b409ad87c9 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/aot/OpenAiRuntimeHintsTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/aot/OpenAiRuntimeHintsTests.java @@ -16,15 +16,20 @@ package org.springframework.ai.openai.aot; +import java.util.HashSet; import java.util.Set; import org.junit.jupiter.api.Test; +import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.ai.openai.api.OpenAiApi; +import org.springframework.ai.openai.api.OpenAiAudioApi; +import org.springframework.ai.openai.api.OpenAiImageApi; import org.springframework.aot.hint.RuntimeHints; import org.springframework.aot.hint.TypeReference; import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage; class OpenAiRuntimeHintsTests { @@ -34,12 +39,26 @@ void registerHints() { OpenAiRuntimeHints openAiRuntimeHints = new OpenAiRuntimeHints(); openAiRuntimeHints.registerHints(runtimeHints, null); - Set jsonAnnotatedClasses = org.springframework.ai.aot.AiRuntimeHints - .findJsonAnnotatedClassesInPackage(OpenAiApi.class); + Set jsonAnnotatedClasses = findJsonAnnotatedClassesInPackage("org.springframework.ai.openai"); + + Set registeredTypes = new HashSet<>(); + runtimeHints.reflection().typeHints().forEach(typeHint -> registeredTypes.add(typeHint.getType())); + for (TypeReference jsonAnnotatedClass : jsonAnnotatedClasses) { - assertThat(runtimeHints).matches(org.springframework.aot.hint.predicate.RuntimeHintsPredicates.reflection() - .onType(jsonAnnotatedClass)); + assertThat(registeredTypes.contains(jsonAnnotatedClass)).isTrue(); } + + // Check a few more specific ones + assertThat(registeredTypes.contains(TypeReference.of(OpenAiApi.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(OpenAiAudioApi.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(OpenAiAudioApi.TtsModel.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(OpenAiAudioApi.WhisperModel.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(OpenAiImageApi.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(OpenAiApi.ChatCompletionFinishReason.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(OpenAiApi.FunctionTool.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(OpenAiApi.FunctionTool.Function.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(OpenAiApi.OutputModality.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(OpenAiChatOptions.class))).isTrue(); } } diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/aot/QianFanRuntimeHints.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/aot/QianFanRuntimeHints.java index 1de207fbf48..0a9ed025df9 100644 --- a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/aot/QianFanRuntimeHints.java +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/aot/QianFanRuntimeHints.java @@ -37,10 +37,7 @@ public class QianFanRuntimeHints implements RuntimeHintsRegistrar { @Override public void registerHints(@NonNull RuntimeHints hints, @Nullable ClassLoader classLoader) { var mcs = MemberCategory.values(); - for (var tr : findJsonAnnotatedClassesInPackage(QianFanApi.class)) { - hints.reflection().registerType(tr, mcs); - } - for (var tr : findJsonAnnotatedClassesInPackage(QianFanImageApi.class)) { + for (var tr : findJsonAnnotatedClassesInPackage("org.springframework.ai.qianfan")) { hints.reflection().registerType(tr, mcs); } } diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/aot/VertexAiGeminiRuntimeHints.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/aot/VertexAiGeminiRuntimeHints.java index 03209da5118..37806e5eac0 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/aot/VertexAiGeminiRuntimeHints.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/aot/VertexAiGeminiRuntimeHints.java @@ -16,7 +16,6 @@ package org.springframework.ai.vertexai.gemini.aot; -import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel; import org.springframework.aot.hint.MemberCategory; import org.springframework.aot.hint.RuntimeHints; import org.springframework.aot.hint.RuntimeHintsRegistrar; @@ -35,7 +34,7 @@ public class VertexAiGeminiRuntimeHints implements RuntimeHintsRegistrar { @Override public void registerHints(RuntimeHints hints, ClassLoader classLoader) { var mcs = MemberCategory.values(); - for (var tr : findJsonAnnotatedClassesInPackage(VertexAiGeminiChatModel.class)) { + for (var tr : findJsonAnnotatedClassesInPackage("org.springframework.ai.vertexai.gemini")) { hints.reflection().registerType(tr, mcs); } } diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/aot/VertexAiGeminiRuntimeHintsTests.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/aot/VertexAiGeminiRuntimeHintsTests.java index 88774e9e356..aa6d7bbc854 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/aot/VertexAiGeminiRuntimeHintsTests.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/aot/VertexAiGeminiRuntimeHintsTests.java @@ -16,17 +16,17 @@ package org.springframework.ai.vertexai.gemini.aot; +import java.util.HashSet; import java.util.Set; import org.junit.jupiter.api.Test; -import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel; +import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatOptions; import org.springframework.aot.hint.RuntimeHints; import org.springframework.aot.hint.TypeReference; import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage; -import static org.springframework.aot.hint.predicate.RuntimeHintsPredicates.reflection; /** * @author Christian Tzolov @@ -39,10 +39,18 @@ void registerHints() { RuntimeHints runtimeHints = new RuntimeHints(); VertexAiGeminiRuntimeHints vertexAiGeminiRuntimeHints = new VertexAiGeminiRuntimeHints(); vertexAiGeminiRuntimeHints.registerHints(runtimeHints, null); - Set jsonAnnotatedClasses = findJsonAnnotatedClassesInPackage(VertexAiGeminiChatModel.class); + + Set jsonAnnotatedClasses = findJsonAnnotatedClassesInPackage( + "org.springframework.ai.vertexai.gemini"); + + Set registeredTypes = new HashSet<>(); + runtimeHints.reflection().typeHints().forEach(typeHint -> registeredTypes.add(typeHint.getType())); + for (TypeReference jsonAnnotatedClass : jsonAnnotatedClasses) { - assertThat(runtimeHints).matches(reflection().onType(jsonAnnotatedClass)); + assertThat(registeredTypes.contains(jsonAnnotatedClass)).isTrue(); } + + assertThat(registeredTypes.contains(TypeReference.of(VertexAiGeminiChatOptions.class))).isTrue(); } } diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/aot/WatsonxAiRuntimeHints.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/aot/WatsonxAiRuntimeHints.java index 34799ecf8b8..c97f66246d0 100644 --- a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/aot/WatsonxAiRuntimeHints.java +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/aot/WatsonxAiRuntimeHints.java @@ -37,10 +37,7 @@ public class WatsonxAiRuntimeHints implements RuntimeHintsRegistrar { @Override public void registerHints(RuntimeHints hints, ClassLoader classLoader) { var mcs = MemberCategory.values(); - for (var tr : findJsonAnnotatedClassesInPackage(WatsonxAiApi.class)) { - hints.reflection().registerType(tr, mcs); - } - for (var tr : findJsonAnnotatedClassesInPackage(WatsonxAiChatOptions.class)) { + for (var tr : findJsonAnnotatedClassesInPackage("org.springframework.ai.watsonx")) { hints.reflection().registerType(tr, mcs); } diff --git a/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/aot/WatsonxAiRuntimeHintsTest.java b/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/aot/WatsonxAiRuntimeHintsTest.java index 7d82cd7a4e3..5252ba3c65c 100644 --- a/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/aot/WatsonxAiRuntimeHintsTest.java +++ b/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/aot/WatsonxAiRuntimeHintsTest.java @@ -16,12 +16,19 @@ package org.springframework.ai.watsonx.aot; +import java.util.HashSet; import java.util.Set; import org.junit.jupiter.api.Test; import org.springframework.ai.watsonx.WatsonxAiChatOptions; import org.springframework.ai.watsonx.api.WatsonxAiApi; +import org.springframework.ai.watsonx.api.WatsonxAiChatRequest; +import org.springframework.ai.watsonx.api.WatsonxAiChatResponse; +import org.springframework.ai.watsonx.api.WatsonxAiChatResults; +import org.springframework.ai.watsonx.api.WatsonxAiEmbeddingRequest; +import org.springframework.ai.watsonx.api.WatsonxAiEmbeddingResponse; +import org.springframework.ai.watsonx.api.WatsonxAiEmbeddingResults; import org.springframework.aot.hint.RuntimeHints; import org.springframework.aot.hint.TypeReference; @@ -38,18 +45,24 @@ public class WatsonxAiRuntimeHintsTest { @Test void registerHints() { RuntimeHints runtimeHints = new RuntimeHints(); - WatsonxAiRuntimeHints watsonxAIRuntimeHintsTest = new WatsonxAiRuntimeHints(); - watsonxAIRuntimeHintsTest.registerHints(runtimeHints, null); + WatsonxAiRuntimeHints watsonxAiRuntimeHints = new WatsonxAiRuntimeHints(); + watsonxAiRuntimeHints.registerHints(runtimeHints, null); - Set jsonAnnotatedClasses = findJsonAnnotatedClassesInPackage(WatsonxAiApi.class); - for (TypeReference jsonAnnotatedClass : jsonAnnotatedClasses) { - assertThat(runtimeHints).matches(reflection().onType(jsonAnnotatedClass)); - } + Set jsonAnnotatedClasses = findJsonAnnotatedClassesInPackage("org.springframework.ai.watsonx"); + + Set registeredTypes = new HashSet<>(); + runtimeHints.reflection().typeHints().forEach(typeHint -> registeredTypes.add(typeHint.getType())); - jsonAnnotatedClasses = findJsonAnnotatedClassesInPackage(WatsonxAiChatOptions.class); for (TypeReference jsonAnnotatedClass : jsonAnnotatedClasses) { - assertThat(runtimeHints).matches(reflection().onType(jsonAnnotatedClass)); + assertThat(registeredTypes.contains(jsonAnnotatedClass)).isTrue(); } + + assertThat(registeredTypes.contains(TypeReference.of(WatsonxAiChatRequest.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(WatsonxAiChatResponse.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(WatsonxAiChatResults.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(WatsonxAiEmbeddingRequest.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(WatsonxAiEmbeddingResponse.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(WatsonxAiEmbeddingResults.class))).isTrue(); } } diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/aot/ZhiPuAiRuntimeHints.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/aot/ZhiPuAiRuntimeHints.java index 75673f1c445..84f989d0452 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/aot/ZhiPuAiRuntimeHints.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/aot/ZhiPuAiRuntimeHints.java @@ -38,10 +38,7 @@ public class ZhiPuAiRuntimeHints implements RuntimeHintsRegistrar { @Override public void registerHints(@NonNull RuntimeHints hints, @Nullable ClassLoader classLoader) { var mcs = MemberCategory.values(); - for (var tr : findJsonAnnotatedClassesInPackage(ZhiPuAiApi.class)) { - hints.reflection().registerType(tr, mcs); - } - for (var tr : findJsonAnnotatedClassesInPackage(ZhiPuAiImageApi.class)) { + for (var tr : findJsonAnnotatedClassesInPackage("org.springframework.ai.zhipuai")) { hints.reflection().registerType(tr, mcs); } } diff --git a/spring-ai-model/src/main/java/org/springframework/ai/aot/AiRuntimeHints.java b/spring-ai-model/src/main/java/org/springframework/ai/aot/AiRuntimeHints.java index 286059bd3af..e6b9b91c80e 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/aot/AiRuntimeHints.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/aot/AiRuntimeHints.java @@ -92,7 +92,7 @@ public static Set findClassesInPackage(String packageName, TypeFi .map(bd -> TypeReference.of(Objects.requireNonNull(bd.getBeanClassName())))// .peek(tr -> { if (log.isDebugEnabled()) { - log.debug("registering [" + tr.getName() + ']'); + log.debug("registering [{}]", tr.getName()); } }) .collect(Collectors.toUnmodifiableSet());