Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.springframework.core.io.support.SpringFactoriesLoader;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatNoException;

/**
* @author Jonathan Leijendekker
Expand Down Expand Up @@ -72,6 +73,12 @@ void dataSourceHasHints() {
assertThat(RuntimeHintsPredicates.reflection().onType(DataSource.class)).accepts(this.hints);
}

@Test
void registerHintsWithNullClassLoader() {
assertThatNoException()
.isThrownBy(() -> this.jdbcChatMemoryRepositoryRuntimeHints.registerHints(this.hints, null));
}

private static Stream<String> getSchemaFileNames() throws IOException {
var resources = new PathMatchingResourcePatternResolver()
.getResources("classpath*:org/springframework/ai/chat/memory/repository/jdbc/schema-*.sql");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,12 @@
import org.springframework.aot.hint.TypeReference;
import org.springframework.util.Assert;

import static org.assertj.core.api.Assertions.assertThat;

class AiRuntimeHintsTests {

@Test
void discoverRelevantClasses() throws Exception {
void discoverRelevantClasses() {
var classes = AiRuntimeHints.findJsonAnnotatedClassesInPackage(TestApi.class);
var included = Set.of(TestApi.Bar.class, TestApi.Foo.class)
.stream()
Expand All @@ -40,6 +42,24 @@ void discoverRelevantClasses() throws Exception {
Assert.state(classes.containsAll(included), "there should be all of the enumerated classes. ");
}

@Test
void verifyRecordWithJsonPropertyIncluded() {
var classes = AiRuntimeHints.findJsonAnnotatedClassesInPackage(TestApi.class);

// Foo record should be included due to @JsonProperty on parameter
var recordClass = TypeReference.of(TestApi.Foo.class.getName());
assertThat(classes).contains(recordClass);
}

@Test
void verifyEnumWithJsonIncludeAnnotation() {
var classes = AiRuntimeHints.findJsonAnnotatedClassesInPackage(TestApi.class);

// Bar enum should be included due to @JsonInclude
var enumClass = TypeReference.of(TestApi.Bar.class.getName());
assertThat(classes).contains(enumClass);
}

@JsonInclude
static class TestApi {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.springframework.aot.hint.RuntimeHints;

import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.springframework.aot.hint.predicate.RuntimeHintsPredicates.reflection;

/**
Expand All @@ -37,4 +38,13 @@ void registerHints() {
assertThat(runtimeHints).matches(reflection().onType(DefaultToolCallResultConverter.class));
}

@Test
void registerHintsWithNullClassLoader() {
RuntimeHints runtimeHints = new RuntimeHints();
ToolRuntimeHints toolRuntimeHints = new ToolRuntimeHints();

// Should not throw exception with null ClassLoader
assertThatCode(() -> toolRuntimeHints.registerHints(runtimeHints, null)).doesNotThrowAnyException();
}

}