Skip to content

Commit 455352f

Browse files
committed
TestcontainersBeanRegistrationAotProcessor that replaces InstanceSupplier of Container by a reflection equivalent
1 parent 4718485 commit 455352f

File tree

4 files changed

+153
-2
lines changed

4 files changed

+153
-2
lines changed

spring-boot-project/spring-boot-testcontainers/src/dockerTest/java/org/springframework/boot/testcontainers/ImportTestcontainersTests.java

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,27 @@
1818

1919
import java.lang.annotation.Retention;
2020
import java.lang.annotation.RetentionPolicy;
21+
import java.util.function.BiConsumer;
2122

2223
import org.junit.jupiter.api.AfterEach;
2324
import org.junit.jupiter.api.Test;
2425
import org.testcontainers.containers.Container;
2526
import org.testcontainers.containers.PostgreSQLContainer;
2627

28+
import org.springframework.aot.test.generate.TestGenerationContext;
2729
import org.springframework.boot.testcontainers.beans.TestcontainerBeanDefinition;
2830
import org.springframework.boot.testcontainers.context.ImportTestcontainers;
31+
import org.springframework.boot.testcontainers.lifecycle.TestcontainersLifecycleApplicationContextInitializer;
2932
import org.springframework.boot.testsupport.container.DisabledIfDockerUnavailable;
3033
import org.springframework.boot.testsupport.container.TestImage;
34+
import org.springframework.context.ApplicationContextInitializer;
3135
import org.springframework.context.annotation.AnnotationConfigApplicationContext;
36+
import org.springframework.context.aot.ApplicationContextAotGenerator;
37+
import org.springframework.context.support.GenericApplicationContext;
38+
import org.springframework.core.test.tools.CompileWithForkedClassLoader;
39+
import org.springframework.core.test.tools.Compiled;
40+
import org.springframework.core.test.tools.TestCompiler;
41+
import org.springframework.javapoet.ClassName;
3242
import org.springframework.test.context.DynamicPropertyRegistry;
3343
import org.springframework.test.context.DynamicPropertySource;
3444

@@ -122,6 +132,34 @@ void importWhenHasBadArgsDynamicPropertySourceMethod() {
122132
.withMessage("@DynamicPropertySource method 'containerProperties' must be static");
123133
}
124134

135+
@Test
136+
@CompileWithForkedClassLoader
137+
void aotContributionRegistersTestcontainers() {
138+
this.applicationContext = new AnnotationConfigApplicationContext();
139+
this.applicationContext.register(ImportWithValue.class);
140+
new TestcontainersLifecycleApplicationContextInitializer().initialize(this.applicationContext);
141+
compile((freshContext, compiled) -> {
142+
PostgreSQLContainer<?> container = freshContext.getBean(PostgreSQLContainer.class);
143+
assertThat(container).isSameAs(ContainerDefinitions.container);
144+
});
145+
}
146+
147+
@SuppressWarnings("unchecked")
148+
private void compile(BiConsumer<GenericApplicationContext, Compiled> result) {
149+
TestGenerationContext generationContext = new TestGenerationContext();
150+
ClassName className = new ApplicationContextAotGenerator().processAheadOfTime(this.applicationContext,
151+
generationContext);
152+
generationContext.writeGeneratedContent();
153+
TestCompiler.forSystem().with(generationContext).compile((compiled) -> {
154+
GenericApplicationContext freshApplicationContext = new GenericApplicationContext();
155+
ApplicationContextInitializer<GenericApplicationContext> initializer = compiled
156+
.getInstance(ApplicationContextInitializer.class, className.toString());
157+
initializer.initialize(freshApplicationContext);
158+
freshApplicationContext.refresh();
159+
result.accept(freshApplicationContext, compiled);
160+
});
161+
}
162+
125163
@ImportTestcontainers
126164
static class ImportWithoutValue {
127165

spring-boot-project/spring-boot-testcontainers/src/main/java/org/springframework/boot/testcontainers/context/TestcontainerFieldBeanDefinition.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2012-2023 the original author or authors.
2+
* Copyright 2012-2024 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -38,9 +38,10 @@ class TestcontainerFieldBeanDefinition extends RootBeanDefinition implements Tes
3838
TestcontainerFieldBeanDefinition(Field field, Container<?> container) {
3939
this.container = container;
4040
this.annotations = MergedAnnotations.from(field);
41-
this.setBeanClass(container.getClass());
41+
setBeanClass(container.getClass());
4242
setInstanceSupplier(() -> container);
4343
setRole(ROLE_INFRASTRUCTURE);
44+
setAttribute(TestcontainerFieldBeanDefinition.class.getName(), field);
4445
}
4546

4647
@Override
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
/*
2+
* Copyright 2012-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.boot.testcontainers.context;
18+
19+
import java.lang.reflect.Field;
20+
21+
import javax.lang.model.element.Modifier;
22+
23+
import org.testcontainers.containers.Container;
24+
25+
import org.springframework.aot.generate.GeneratedMethod;
26+
import org.springframework.aot.generate.GenerationContext;
27+
import org.springframework.beans.factory.aot.BeanRegistrationAotContribution;
28+
import org.springframework.beans.factory.aot.BeanRegistrationAotProcessor;
29+
import org.springframework.beans.factory.aot.BeanRegistrationCode;
30+
import org.springframework.beans.factory.aot.BeanRegistrationCodeFragments;
31+
import org.springframework.beans.factory.aot.BeanRegistrationCodeFragmentsDecorator;
32+
import org.springframework.beans.factory.support.InstanceSupplier;
33+
import org.springframework.beans.factory.support.RegisteredBean;
34+
import org.springframework.beans.factory.support.RootBeanDefinition;
35+
import org.springframework.javapoet.ClassName;
36+
import org.springframework.javapoet.CodeBlock;
37+
import org.springframework.util.Assert;
38+
import org.springframework.util.ClassUtils;
39+
import org.springframework.util.ReflectionUtils;
40+
41+
/**
42+
* {@link BeanRegistrationAotProcessor} that replaces InstanceSupplier of
43+
* {@link Container} by a reflection equivalent.
44+
*
45+
* @author Dmytro Nosan
46+
*/
47+
class TestcontainersBeanRegistrationAotProcessor implements BeanRegistrationAotProcessor {
48+
49+
@Override
50+
public BeanRegistrationAotContribution processAheadOfTime(RegisteredBean registeredBean) {
51+
RootBeanDefinition bd = registeredBean.getMergedBeanDefinition();
52+
String attributeName = TestcontainerFieldBeanDefinition.class.getName();
53+
Object field = bd.getAttribute(attributeName);
54+
if (field != null) {
55+
Assert.isInstanceOf(Field.class, field, "BeanDefinition attribute '" + attributeName
56+
+ "' value must be a type of '" + field.getClass().getName() + "'");
57+
return BeanRegistrationAotContribution.withCustomCodeFragments(
58+
(codeFragments) -> new AotContribution(codeFragments, registeredBean, ((Field) field)));
59+
}
60+
return null;
61+
}
62+
63+
static class AotContribution extends BeanRegistrationCodeFragmentsDecorator {
64+
65+
private final RegisteredBean registeredBean;
66+
67+
private final Field field;
68+
69+
AotContribution(BeanRegistrationCodeFragments delegate, RegisteredBean registeredBean, Field field) {
70+
super(delegate);
71+
this.registeredBean = registeredBean;
72+
this.field = field;
73+
}
74+
75+
@Override
76+
public ClassName getTarget(RegisteredBean registeredBean) {
77+
return ClassName.get(this.registeredBean.getBeanClass());
78+
}
79+
80+
@Override
81+
public CodeBlock generateInstanceSupplierCode(GenerationContext generationContext,
82+
BeanRegistrationCode beanRegistrationCode, boolean allowDirectSupplierShortcut) {
83+
Class<?> beanClass = this.registeredBean.getBeanClass();
84+
Class<?> testClass = this.field.getDeclaringClass();
85+
String fieldName = this.field.getName();
86+
GeneratedMethod generatedMethod = beanRegistrationCode.getMethods().add("getInstance", (method) -> {
87+
method.addJavadoc("Get the bean instance for '$L'.", this.registeredBean.getBeanName())
88+
.addModifiers(Modifier.PRIVATE, Modifier.STATIC)
89+
.returns(beanClass)
90+
.addStatement("$T<?> testClass = $T.forName($S, null)", Class.class, ClassUtils.class,
91+
testClass.getName())
92+
.addStatement("$T field = $T.findField(testClass, $S)", Field.class, ReflectionUtils.class,
93+
fieldName)
94+
.addStatement("$T.notNull(field, $S)", Assert.class, "Field '" + fieldName + "' is not found")
95+
.addStatement("$T.makeAccessible(field)", ReflectionUtils.class)
96+
.addStatement("$T container = ($T) $T.getField(field, null)", beanClass, beanClass,
97+
ReflectionUtils.class)
98+
.addStatement("$T.notNull(container, $S)", Assert.class,
99+
"Container field '" + fieldName + "' must not have a null value")
100+
.addStatement("return container")
101+
.addException(ClassNotFoundException.class);
102+
});
103+
return CodeBlock.of("$T.using($T::$L)", InstanceSupplier.class, beanRegistrationCode.getClassName(),
104+
generatedMethod.getName());
105+
}
106+
107+
}
108+
109+
}

spring-boot-project/spring-boot-testcontainers/src/main/resources/META-INF/spring/aot.factories

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,6 @@ org.springframework.boot.testcontainers.service.connection.ConnectionDetailsRegi
33

44
org.springframework.aot.hint.RuntimeHintsRegistrar=\
55
org.springframework.boot.testcontainers.service.connection.ContainerConnectionDetailsFactory.ContainerConnectionDetailsFactoriesRuntimeHints
6+
7+
org.springframework.beans.factory.aot.BeanRegistrationAotProcessor=\
8+
org.springframework.boot.testcontainers.context.TestcontainersBeanRegistrationAotProcessor

0 commit comments

Comments
 (0)