Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package org.springframework.ai.anthropic.aot;

import org.springframework.ai.anthropic.AnthropicChatOptions;
import org.springframework.ai.anthropic.api.AnthropicApi;
import org.springframework.aot.hint.MemberCategory;
import org.springframework.aot.hint.RuntimeHints;
Expand All @@ -37,7 +38,8 @@ public class AnthropicRuntimeHints implements RuntimeHintsRegistrar {
@Override
public void registerHints(@NonNull RuntimeHints hints, @Nullable ClassLoader classLoader) {
var mcs = MemberCategory.values();
for (var tr : findJsonAnnotatedClassesInPackage(AnthropicApi.class)) {

for (var tr : findJsonAnnotatedClassesInPackage("org.springframework.ai.anthropic")) {
hints.reflection().registerType(tr, mcs);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package org.springframework.ai.anthropic.aot;

import java.util.HashSet;
import java.util.Set;

import org.junit.jupiter.api.Test;
Expand All @@ -26,7 +27,6 @@

import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage;
import static org.springframework.aot.hint.predicate.RuntimeHintsPredicates.reflection;

class AnthropicRuntimeHintsTests {

Expand All @@ -36,10 +36,23 @@ void registerHints() {
AnthropicRuntimeHints anthropicRuntimeHints = new AnthropicRuntimeHints();
anthropicRuntimeHints.registerHints(runtimeHints, null);

Set<TypeReference> jsonAnnotatedClasses = findJsonAnnotatedClassesInPackage(AnthropicApi.class);
Set<TypeReference> jsonAnnotatedClasses = findJsonAnnotatedClassesInPackage("org.springframework.ai.anthropic");

Set<TypeReference> registeredTypes = new HashSet<>();
runtimeHints.reflection().typeHints().forEach(typeHint -> registeredTypes.add(typeHint.getType()));

for (TypeReference jsonAnnotatedClass : jsonAnnotatedClasses) {
assertThat(runtimeHints).matches(reflection().onType(jsonAnnotatedClass));
assertThat(registeredTypes.contains(jsonAnnotatedClass)).isTrue();
}

// Check a few more specific ones
assertThat(registeredTypes.contains(TypeReference.of(AnthropicApi.class))).isTrue();
assertThat(registeredTypes.contains(TypeReference.of(AnthropicApi.Role.class))).isTrue();
assertThat(registeredTypes.contains(TypeReference.of(AnthropicApi.ThinkingType.class))).isTrue();
assertThat(registeredTypes.contains(TypeReference.of(AnthropicApi.EventType.class))).isTrue();
assertThat(registeredTypes.contains(TypeReference.of(AnthropicApi.ContentBlock.class))).isTrue();
assertThat(registeredTypes.contains(TypeReference.of(AnthropicApi.ChatCompletionRequest.class))).isTrue();
assertThat(registeredTypes.contains(TypeReference.of(AnthropicApi.AnthropicMessage.class))).isTrue();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -41,21 +41,8 @@ public class BedrockRuntimeHints implements RuntimeHintsRegistrar {
@Override
public void registerHints(RuntimeHints hints, ClassLoader classLoader) {
var mcs = MemberCategory.values();
for (var tr : findJsonAnnotatedClassesInPackage(AbstractBedrockApi.class)) {
hints.reflection().registerType(tr, mcs);
}

for (var tr : findJsonAnnotatedClassesInPackage(CohereEmbeddingBedrockApi.class)) {
hints.reflection().registerType(tr, mcs);
}
for (var tr : findJsonAnnotatedClassesInPackage(BedrockCohereEmbeddingOptions.class)) {
hints.reflection().registerType(tr, mcs);
}

for (var tr : findJsonAnnotatedClassesInPackage(BedrockTitanEmbeddingOptions.class)) {
hints.reflection().registerType(tr, mcs);
}
for (var tr : findJsonAnnotatedClassesInPackage(TitanEmbeddingBedrockApi.class)) {
for (var tr : findJsonAnnotatedClassesInPackage("org.springframework.ai.bedrock")) {
hints.reflection().registerType(tr, mcs);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,21 @@

package org.springframework.ai.bedrock.aot;

import java.util.Arrays;
import java.util.List;
import java.util.HashSet;
import java.util.Set;

import org.junit.jupiter.api.Test;

import org.springframework.ai.bedrock.api.AbstractBedrockApi;
import org.springframework.ai.bedrock.cohere.BedrockCohereEmbeddingOptions;
import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi;
import org.springframework.ai.bedrock.titan.BedrockTitanEmbeddingOptions;
import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.aot.hint.TypeReference;

import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage;
import static org.springframework.aot.hint.predicate.RuntimeHintsPredicates.reflection;

class BedrockRuntimeHintsTests {

Expand All @@ -39,15 +40,22 @@ void registerHints() {
BedrockRuntimeHints bedrockRuntimeHints = new BedrockRuntimeHints();
bedrockRuntimeHints.registerHints(runtimeHints, null);

List<Class> classList = Arrays.asList(CohereEmbeddingBedrockApi.class, TitanEmbeddingBedrockApi.class);
Set<TypeReference> jsonAnnotatedClasses = findJsonAnnotatedClassesInPackage("org.springframework.ai.bedrock");

for (Class aClass : classList) {
Set<TypeReference> jsonAnnotatedClasses = findJsonAnnotatedClassesInPackage(aClass);
for (TypeReference jsonAnnotatedClass : jsonAnnotatedClasses) {
assertThat(runtimeHints).matches(reflection().onType(jsonAnnotatedClass));
}
Set<TypeReference> registeredTypes = new HashSet<>();
runtimeHints.reflection().typeHints().forEach(typeHint -> registeredTypes.add(typeHint.getType()));

for (TypeReference jsonAnnotatedClass : jsonAnnotatedClasses) {
assertThat(registeredTypes.contains(jsonAnnotatedClass)).isTrue();
}

// Check a few more specific ones
assertThat(registeredTypes.contains(TypeReference.of(AbstractBedrockApi.AmazonBedrockInvocationMetrics.class)))
.isTrue();
assertThat(registeredTypes.contains(TypeReference.of(CohereEmbeddingBedrockApi.class))).isTrue();
assertThat(registeredTypes.contains(TypeReference.of(BedrockCohereEmbeddingOptions.class))).isTrue();
assertThat(registeredTypes.contains(TypeReference.of(BedrockTitanEmbeddingOptions.class))).isTrue();
assertThat(registeredTypes.contains(TypeReference.of(TitanEmbeddingBedrockApi.class))).isTrue();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ public class MiniMaxRuntimeHints implements RuntimeHintsRegistrar {
@Override
public void registerHints(@NonNull RuntimeHints hints, @Nullable ClassLoader classLoader) {
var mcs = MemberCategory.values();
for (var tr : findJsonAnnotatedClassesInPackage(MiniMaxApi.class)) {

for (var tr : findJsonAnnotatedClassesInPackage("org.springframework.ai.minimax")) {
hints.reflection().registerType(tr, mcs);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@

package org.springframework.ai.mistralai.aot;

import org.springframework.ai.mistralai.api.MistralAiApi;
import org.springframework.aot.hint.MemberCategory;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.aot.hint.RuntimeHintsRegistrar;
import org.springframework.lang.NonNull;
import org.springframework.lang.Nullable;

import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage;

/**
* The MistralAiRuntimeHints class is responsible for registering runtime hints for
* Mistral AI API classes.
Expand All @@ -35,7 +36,8 @@ public class MistralAiRuntimeHints implements RuntimeHintsRegistrar {
@Override
public void registerHints(@NonNull RuntimeHints hints, @Nullable ClassLoader classLoader) {
var mcs = MemberCategory.values();
for (var tr : org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage(MistralAiApi.class)) {

for (var tr : findJsonAnnotatedClassesInPackage("org.springframework.ai.mistralai")) {
hints.reflection().registerType(tr, mcs);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,19 @@

package org.springframework.ai.mistralai.aot;

import java.util.HashSet;
import java.util.Set;

import org.junit.jupiter.api.Test;

import org.springframework.ai.mistralai.MistralAiChatOptions;
import org.springframework.ai.mistralai.MistralAiEmbeddingOptions;
import org.springframework.ai.mistralai.api.MistralAiApi;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.aot.hint.TypeReference;

import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage;
import static org.springframework.aot.hint.predicate.RuntimeHintsPredicates.reflection;

class MistralAiRuntimeHintsTests {
Expand All @@ -35,11 +39,22 @@ void registerHints() {
MistralAiRuntimeHints mistralAiRuntimeHints = new MistralAiRuntimeHints();
mistralAiRuntimeHints.registerHints(runtimeHints, null);

Set<TypeReference> jsonAnnotatedClasses = org.springframework.ai.aot.AiRuntimeHints
.findJsonAnnotatedClassesInPackage(MistralAiApi.class);
Set<TypeReference> jsonAnnotatedClasses = findJsonAnnotatedClassesInPackage("org.springframework.ai.mistralai");

Set<TypeReference> registeredTypes = new HashSet<>();
runtimeHints.reflection().typeHints().forEach(typeHint -> registeredTypes.add(typeHint.getType()));

for (TypeReference jsonAnnotatedClass : jsonAnnotatedClasses) {
assertThat(runtimeHints).matches(reflection().onType(jsonAnnotatedClass));
assertThat(registeredTypes.contains(jsonAnnotatedClass)).isTrue();
}

// Check a few more specific ones
assertThat(registeredTypes.contains(TypeReference.of(MistralAiApi.ChatCompletion.class))).isTrue();
assertThat(registeredTypes.contains(TypeReference.of(MistralAiApi.ChatCompletionChunk.class))).isTrue();
assertThat(registeredTypes.contains(TypeReference.of(MistralAiApi.LogProbs.class))).isTrue();
assertThat(registeredTypes.contains(TypeReference.of(MistralAiApi.ChatCompletionFinishReason.class))).isTrue();
assertThat(registeredTypes.contains(TypeReference.of(MistralAiChatOptions.class))).isTrue();
assertThat(registeredTypes.contains(TypeReference.of(MistralAiEmbeddingOptions.class))).isTrue();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public class MoonshotRuntimeHints implements RuntimeHintsRegistrar {
@Override
public void registerHints(@NonNull RuntimeHints hints, @Nullable ClassLoader classLoader) {
var mcs = MemberCategory.values();
for (var tr : findJsonAnnotatedClassesInPackage(MoonshotApi.class)) {
for (var tr : findJsonAnnotatedClassesInPackage("org.springframework.ai.moonshot")) {
hints.reflection().registerType(tr, mcs);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@

package org.springframework.ai.moonshot.aot;

import java.util.HashSet;
import java.util.Set;

import org.junit.jupiter.api.Test;

import org.springframework.ai.moonshot.MoonshotChatOptions;
import org.springframework.ai.moonshot.api.MoonshotApi;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.aot.hint.TypeReference;
Expand All @@ -39,10 +41,21 @@ void registerHints() {
MoonshotRuntimeHints moonshotRuntimeHints = new MoonshotRuntimeHints();
moonshotRuntimeHints.registerHints(runtimeHints, null);

Set<TypeReference> jsonAnnotatedClasses = findJsonAnnotatedClassesInPackage(MoonshotApi.class);
Set<TypeReference> jsonAnnotatedClasses = findJsonAnnotatedClassesInPackage("org.springframework.ai.moonshot");

Set<TypeReference> registeredTypes = new HashSet<>();
runtimeHints.reflection().typeHints().forEach(typeHint -> registeredTypes.add(typeHint.getType()));

for (TypeReference jsonAnnotatedClass : jsonAnnotatedClasses) {
assertThat(runtimeHints).matches(reflection().onType(jsonAnnotatedClass));
assertThat(registeredTypes.contains(jsonAnnotatedClass)).isTrue();
}

// Check a few more specific ones
assertThat(registeredTypes.contains(TypeReference.of(MoonshotApi.ChatCompletion.class))).isTrue();
assertThat(registeredTypes.contains(TypeReference.of(MoonshotApi.ChatCompletionRequest.class))).isTrue();
assertThat(registeredTypes.contains(TypeReference.of(MoonshotApi.ChatCompletionChunk.class))).isTrue();
assertThat(registeredTypes.contains(TypeReference.of(MoonshotApi.Usage.class))).isTrue();
assertThat(registeredTypes.contains(TypeReference.of(MoonshotChatOptions.class))).isTrue();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,7 @@ public class OllamaRuntimeHints implements RuntimeHintsRegistrar {
@Override
public void registerHints(RuntimeHints hints, ClassLoader classLoader) {
var mcs = MemberCategory.values();
for (var tr : findJsonAnnotatedClassesInPackage(OllamaApi.class)) {
hints.reflection().registerType(tr, mcs);
}
for (var tr : findJsonAnnotatedClassesInPackage(OllamaOptions.class)) {
for (var tr : findJsonAnnotatedClassesInPackage("org.springframework.ai.ollama")) {
hints.reflection().registerType(tr, mcs);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package org.springframework.ai.ollama.aot;

import java.util.HashSet;
import java.util.Set;

import org.junit.jupiter.api.Test;
Expand All @@ -37,15 +38,20 @@ void registerHints() {
OllamaRuntimeHints ollamaRuntimeHints = new OllamaRuntimeHints();
ollamaRuntimeHints.registerHints(runtimeHints, null);

Set<TypeReference> jsonAnnotatedClasses = findJsonAnnotatedClassesInPackage(OllamaApi.class);
for (TypeReference jsonAnnotatedClass : jsonAnnotatedClasses) {
assertThat(runtimeHints).matches(reflection().onType(jsonAnnotatedClass));
}
Set<TypeReference> jsonAnnotatedClasses = findJsonAnnotatedClassesInPackage("org.springframework.ai.ollama");

Set<TypeReference> registeredTypes = new HashSet<>();
runtimeHints.reflection().typeHints().forEach(typeHint -> registeredTypes.add(typeHint.getType()));

jsonAnnotatedClasses = findJsonAnnotatedClassesInPackage(OllamaOptions.class);
for (TypeReference jsonAnnotatedClass : jsonAnnotatedClasses) {
assertThat(runtimeHints).matches(reflection().onType(jsonAnnotatedClass));
assertThat(registeredTypes.contains(jsonAnnotatedClass)).isTrue();
}

// Check a few more specific ones
assertThat(registeredTypes.contains(TypeReference.of(OllamaApi.ChatRequest.class))).isTrue();
assertThat(registeredTypes.contains(TypeReference.of(OllamaApi.ChatRequest.Tool.class))).isTrue();
assertThat(registeredTypes.contains(TypeReference.of(OllamaApi.Message.class))).isTrue();
assertThat(registeredTypes.contains(TypeReference.of(OllamaOptions.class))).isTrue();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,10 @@ public class OpenAiRuntimeHints implements RuntimeHintsRegistrar {
@Override
public void registerHints(@NonNull RuntimeHints hints, @Nullable ClassLoader classLoader) {
var mcs = MemberCategory.values();
for (var tr : (findJsonAnnotatedClassesInPackage(OpenAiChatOptions.class))) {
hints.reflection().registerType(tr, mcs);
}
for (var tr : (findJsonAnnotatedClassesInPackage(OpenAiApi.class))) {
hints.reflection().registerType(tr, mcs);
}
for (var tr : (findJsonAnnotatedClassesInPackage(OpenAiAudioApi.class))) {
hints.reflection().registerType(tr, mcs);
}
for (var tr : findJsonAnnotatedClassesInPackage(OpenAiImageApi.class)) {

for (var tr : findJsonAnnotatedClassesInPackage("org.springframework.ai.openai")) {
hints.reflection().registerType(tr, mcs);
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,20 @@

package org.springframework.ai.openai.aot;

import java.util.HashSet;
import java.util.Set;

import org.junit.jupiter.api.Test;

import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.openai.api.OpenAiAudioApi;
import org.springframework.ai.openai.api.OpenAiImageApi;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.aot.hint.TypeReference;

import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage;

class OpenAiRuntimeHintsTests {

Expand All @@ -34,12 +39,26 @@ void registerHints() {
OpenAiRuntimeHints openAiRuntimeHints = new OpenAiRuntimeHints();
openAiRuntimeHints.registerHints(runtimeHints, null);

Set<TypeReference> jsonAnnotatedClasses = org.springframework.ai.aot.AiRuntimeHints
.findJsonAnnotatedClassesInPackage(OpenAiApi.class);
Set<TypeReference> jsonAnnotatedClasses = findJsonAnnotatedClassesInPackage("org.springframework.ai.openai");

Set<TypeReference> registeredTypes = new HashSet<>();
runtimeHints.reflection().typeHints().forEach(typeHint -> registeredTypes.add(typeHint.getType()));

for (TypeReference jsonAnnotatedClass : jsonAnnotatedClasses) {
assertThat(runtimeHints).matches(org.springframework.aot.hint.predicate.RuntimeHintsPredicates.reflection()
.onType(jsonAnnotatedClass));
assertThat(registeredTypes.contains(jsonAnnotatedClass)).isTrue();
}

// Check a few more specific ones
assertThat(registeredTypes.contains(TypeReference.of(OpenAiApi.class))).isTrue();
assertThat(registeredTypes.contains(TypeReference.of(OpenAiAudioApi.class))).isTrue();
assertThat(registeredTypes.contains(TypeReference.of(OpenAiAudioApi.TtsModel.class))).isTrue();
assertThat(registeredTypes.contains(TypeReference.of(OpenAiAudioApi.WhisperModel.class))).isTrue();
assertThat(registeredTypes.contains(TypeReference.of(OpenAiImageApi.class))).isTrue();
assertThat(registeredTypes.contains(TypeReference.of(OpenAiApi.ChatCompletionFinishReason.class))).isTrue();
assertThat(registeredTypes.contains(TypeReference.of(OpenAiApi.FunctionTool.class))).isTrue();
assertThat(registeredTypes.contains(TypeReference.of(OpenAiApi.FunctionTool.Function.class))).isTrue();
assertThat(registeredTypes.contains(TypeReference.of(OpenAiApi.OutputModality.class))).isTrue();
assertThat(registeredTypes.contains(TypeReference.of(OpenAiChatOptions.class))).isTrue();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,7 @@ public class QianFanRuntimeHints implements RuntimeHintsRegistrar {
@Override
public void registerHints(@NonNull RuntimeHints hints, @Nullable ClassLoader classLoader) {
var mcs = MemberCategory.values();
for (var tr : findJsonAnnotatedClassesInPackage(QianFanApi.class)) {
hints.reflection().registerType(tr, mcs);
}
for (var tr : findJsonAnnotatedClassesInPackage(QianFanImageApi.class)) {
for (var tr : findJsonAnnotatedClassesInPackage("org.springframework.ai.qianfan")) {
hints.reflection().registerType(tr, mcs);
}
}
Expand Down
Loading