Skip to content

Commit 9ba9272

Browse files
committed
Add support for ImportAware callback
This commit adds a way for a BeanFactoryPostProcessor to participate to AOT optimizations by contributing code that replaces its runtime behaviour. ConfigurationClassPostProcessor does implement this new interface and computes a mapping of the ImportAware configuration classes. The mapping is generated for latter reuse by ImportAwareAotBeanPostProcessor. Closes gh-2811
1 parent ec6a19f commit 9ba9272

File tree

7 files changed

+467
-2
lines changed

7 files changed

+467
-2
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/*
2+
* Copyright 2002-2022 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.beans.factory.generator;
18+
19+
import org.springframework.beans.BeansException;
20+
import org.springframework.beans.factory.config.BeanFactoryPostProcessor;
21+
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
22+
import org.springframework.lang.Nullable;
23+
24+
/**
25+
* Specialization of {@link BeanFactoryPostProcessor} that contributes bean
26+
* factory optimizations ahead of time, using generated code that replaces
27+
* runtime behavior.
28+
*
29+
* @author Stephane Nicoll
30+
* @since 6.0
31+
*/
32+
@FunctionalInterface
33+
public interface AotContributingBeanFactoryPostProcessor extends BeanFactoryPostProcessor {
34+
35+
/**
36+
* Contribute a {@link BeanFactoryContribution} for the given bean factory,
37+
* if applicable.
38+
* @param beanFactory the bean factory to optimize
39+
* @return the contribution to use or {@code null}
40+
*/
41+
@Nullable
42+
BeanFactoryContribution contribute(ConfigurableListableBeanFactory beanFactory);
43+
44+
@Override
45+
default void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException {
46+
47+
}
48+
49+
}

spring-context/src/main/java/org/springframework/context/annotation/ConfigurationClassPostProcessor.java

Lines changed: 72 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2021 the original author or authors.
2+
* Copyright 2002-2022 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -18,17 +18,22 @@
1818

1919
import java.util.ArrayList;
2020
import java.util.Arrays;
21+
import java.util.HashMap;
2122
import java.util.HashSet;
2223
import java.util.LinkedHashMap;
2324
import java.util.LinkedHashSet;
2425
import java.util.List;
2526
import java.util.Map;
2627
import java.util.Set;
2728

29+
import javax.lang.model.element.Modifier;
30+
2831
import org.apache.commons.logging.Log;
2932
import org.apache.commons.logging.LogFactory;
3033

3134
import org.springframework.aop.framework.autoproxy.AutoProxyUtils;
35+
import org.springframework.aot.hint.ResourceHints;
36+
import org.springframework.aot.hint.TypeReference;
3237
import org.springframework.beans.PropertyValues;
3338
import org.springframework.beans.factory.BeanClassLoaderAware;
3439
import org.springframework.beans.factory.BeanDefinitionStoreException;
@@ -40,6 +45,9 @@
4045
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
4146
import org.springframework.beans.factory.config.InstantiationAwareBeanPostProcessor;
4247
import org.springframework.beans.factory.config.SingletonBeanRegistry;
48+
import org.springframework.beans.factory.generator.AotContributingBeanFactoryPostProcessor;
49+
import org.springframework.beans.factory.generator.BeanFactoryContribution;
50+
import org.springframework.beans.factory.generator.BeanFactoryInitialization;
4351
import org.springframework.beans.factory.parsing.FailFastProblemReporter;
4452
import org.springframework.beans.factory.parsing.PassThroughSourceExtractor;
4553
import org.springframework.beans.factory.parsing.ProblemReporter;
@@ -65,6 +73,10 @@
6573
import org.springframework.core.type.MethodMetadata;
6674
import org.springframework.core.type.classreading.CachingMetadataReaderFactory;
6775
import org.springframework.core.type.classreading.MetadataReaderFactory;
76+
import org.springframework.javapoet.CodeBlock;
77+
import org.springframework.javapoet.CodeBlock.Builder;
78+
import org.springframework.javapoet.MethodSpec;
79+
import org.springframework.javapoet.ParameterizedTypeName;
6880
import org.springframework.lang.Nullable;
6981
import org.springframework.util.Assert;
7082
import org.springframework.util.ClassUtils;
@@ -89,7 +101,8 @@
89101
* @since 3.0
90102
*/
91103
public class ConfigurationClassPostProcessor implements BeanDefinitionRegistryPostProcessor,
92-
PriorityOrdered, ResourceLoaderAware, ApplicationStartupAware, BeanClassLoaderAware, EnvironmentAware {
104+
AotContributingBeanFactoryPostProcessor, PriorityOrdered, ResourceLoaderAware, ApplicationStartupAware,
105+
BeanClassLoaderAware, EnvironmentAware {
93106

94107
/**
95108
* A {@code BeanNameGenerator} using fully qualified class names as default bean names.
@@ -269,6 +282,12 @@ public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory)
269282
beanFactory.addBeanPostProcessor(new ImportAwareBeanPostProcessor(beanFactory));
270283
}
271284

285+
@Override
286+
public BeanFactoryContribution contribute(ConfigurableListableBeanFactory beanFactory) {
287+
return (beanFactory.containsBean(IMPORT_REGISTRY_BEAN_NAME)
288+
? new ImportAwareBeanFactoryConfiguration(beanFactory) : null);
289+
}
290+
272291
/**
273292
* Build and validate a configuration model based on the registry of
274293
* {@link Configuration} classes.
@@ -485,4 +504,55 @@ public Object postProcessBeforeInitialization(Object bean, String beanName) {
485504
}
486505
}
487506

507+
private static final class ImportAwareBeanFactoryConfiguration implements BeanFactoryContribution {
508+
509+
private final ConfigurableListableBeanFactory beanFactory;
510+
511+
private ImportAwareBeanFactoryConfiguration(ConfigurableListableBeanFactory beanFactory) {
512+
this.beanFactory = beanFactory;
513+
}
514+
515+
516+
@Override
517+
public void applyTo(BeanFactoryInitialization initialization) {
518+
Map<String, String> mappings = buildImportAwareMappings();
519+
if (!mappings.isEmpty()) {
520+
MethodSpec method = initialization.generatedTypeContext().getMainGeneratedType()
521+
.addMethod(beanPostProcessorMethod(mappings));
522+
initialization.contribute(code -> code.addStatement("beanFactory.addBeanPostProcessor($N())", method));
523+
ResourceHints resourceHints = initialization.generatedTypeContext().runtimeHints().resources();
524+
mappings.forEach((target, importedFrom) -> resourceHints.registerType(
525+
TypeReference.of(importedFrom)));
526+
}
527+
}
528+
529+
private MethodSpec.Builder beanPostProcessorMethod(Map<String, String> mappings) {
530+
Builder code = CodeBlock.builder();
531+
code.addStatement("$T mappings = new $T<>()", ParameterizedTypeName.get(
532+
Map.class, String.class, String.class), HashMap.class);
533+
mappings.forEach((key, value) -> code.addStatement("mappings.put($S, $S)", key, value));
534+
code.addStatement("return new $T($L)", ImportAwareAotBeanPostProcessor.class, "mappings");
535+
return MethodSpec.methodBuilder("createImportAwareBeanPostProcessor")
536+
.returns(ImportAwareAotBeanPostProcessor.class)
537+
.addModifiers(Modifier.PRIVATE).addCode(code.build());
538+
}
539+
540+
private Map<String, String> buildImportAwareMappings() {
541+
ImportRegistry ir = this.beanFactory.getBean(IMPORT_REGISTRY_BEAN_NAME, ImportRegistry.class);
542+
Map<String, String> mappings = new LinkedHashMap<>();
543+
for (String name : this.beanFactory.getBeanDefinitionNames()) {
544+
Class<?> beanType = this.beanFactory.getType(name);
545+
if (beanType != null && ImportAware.class.isAssignableFrom(beanType)) {
546+
String type = ClassUtils.getUserClass(beanType).getName();
547+
AnnotationMetadata importingClassMetadata = ir.getImportingClassFor(type);
548+
if (importingClassMetadata != null) {
549+
mappings.put(type, importingClassMetadata.getClassName());
550+
}
551+
}
552+
}
553+
return mappings;
554+
}
555+
556+
}
557+
488558
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
/*
2+
* Copyright 2002-2022 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.context.annotation;
18+
19+
import java.io.IOException;
20+
import java.util.Map;
21+
22+
import org.springframework.beans.factory.config.BeanPostProcessor;
23+
import org.springframework.core.type.classreading.CachingMetadataReaderFactory;
24+
import org.springframework.core.type.classreading.MetadataReader;
25+
import org.springframework.core.type.classreading.MetadataReaderFactory;
26+
import org.springframework.lang.Nullable;
27+
import org.springframework.util.ClassUtils;
28+
29+
/**
30+
* A {@link BeanPostProcessor} that honours {@link ImportAware} callback using
31+
* a mapping computed at build time.
32+
*
33+
* @author Stephane Nicoll
34+
* @since 6.0
35+
*/
36+
public final class ImportAwareAotBeanPostProcessor implements BeanPostProcessor {
37+
38+
private final MetadataReaderFactory metadataReaderFactory;
39+
40+
private final Map<String, String> importsMapping;
41+
42+
public ImportAwareAotBeanPostProcessor(Map<String, String> importsMapping) {
43+
this.metadataReaderFactory = new CachingMetadataReaderFactory();
44+
this.importsMapping = Map.copyOf(importsMapping);
45+
}
46+
47+
@Override
48+
public Object postProcessBeforeInitialization(Object bean, String beanName) {
49+
if (bean instanceof ImportAware) {
50+
setAnnotationMetadata((ImportAware) bean);
51+
}
52+
return bean;
53+
}
54+
55+
private void setAnnotationMetadata(ImportAware instance) {
56+
String importingClass = getImportingClassFor(instance);
57+
if (importingClass == null) {
58+
return; // import aware configuration class not imported
59+
}
60+
try {
61+
MetadataReader metadataReader = this.metadataReaderFactory.getMetadataReader(importingClass);
62+
instance.setImportMetadata(metadataReader.getAnnotationMetadata());
63+
}
64+
catch (IOException ex) {
65+
throw new IllegalStateException(String.format("Failed to read metadata for '%s'", importingClass), ex);
66+
}
67+
}
68+
69+
@Nullable
70+
private String getImportingClassFor(ImportAware instance) {
71+
String target = ClassUtils.getUserClass(instance).getName();
72+
return this.importsMapping.get(target);
73+
}
74+
75+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
/*
2+
* Copyright 2002-2022 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.context.annotation;
18+
19+
import java.util.Map;
20+
21+
import org.junit.jupiter.api.Test;
22+
23+
import org.springframework.core.type.AnnotationMetadata;
24+
25+
import static org.assertj.core.api.Assertions.assertThat;
26+
import static org.assertj.core.api.Assertions.assertThatIllegalStateException;
27+
28+
/**
29+
* Tests for {@link ImportAwareAotBeanPostProcessor}.
30+
*
31+
* @author Stephane Nicoll
32+
*/
33+
class ImportAwareAotBeanPostProcessorTests {
34+
35+
@Test
36+
void postProcessOnMatchingCandidate() {
37+
ImportAwareAotBeanPostProcessor postProcessor = new ImportAwareAotBeanPostProcessor(
38+
Map.of(TestImportAware.class.getName(), ImportAwareAotBeanPostProcessorTests.class.getName()));
39+
TestImportAware importAware = new TestImportAware();
40+
postProcessor.postProcessBeforeInitialization(importAware, "test");
41+
assertThat(importAware.importMetadata).isNotNull();
42+
assertThat(importAware.importMetadata.getClassName())
43+
.isEqualTo(ImportAwareAotBeanPostProcessorTests.class.getName());
44+
}
45+
46+
@Test
47+
void postProcessOnMatchingCandidateWithNestedClass() {
48+
ImportAwareAotBeanPostProcessor postProcessor = new ImportAwareAotBeanPostProcessor(
49+
Map.of(TestImportAware.class.getName(), TestImporting.class.getName()));
50+
TestImportAware importAware = new TestImportAware();
51+
postProcessor.postProcessBeforeInitialization(importAware, "test");
52+
assertThat(importAware.importMetadata).isNotNull();
53+
assertThat(importAware.importMetadata.getClassName())
54+
.isEqualTo(TestImporting.class.getName());
55+
}
56+
57+
@Test
58+
void postProcessOnNoCandidateDoesNotInvokeCallback() {
59+
ImportAwareAotBeanPostProcessor postProcessor = new ImportAwareAotBeanPostProcessor(
60+
Map.of(String.class.getName(), ImportAwareAotBeanPostProcessorTests.class.getName()));
61+
TestImportAware importAware = new TestImportAware();
62+
postProcessor.postProcessBeforeInitialization(importAware, "test");
63+
assertThat(importAware.importMetadata).isNull();
64+
}
65+
66+
@Test
67+
void postProcessOnMatchingCandidateWithNoMetadata() {
68+
ImportAwareAotBeanPostProcessor postProcessor = new ImportAwareAotBeanPostProcessor(
69+
Map.of(TestImportAware.class.getName(), "com.example.invalid.DoesNotExist"));
70+
TestImportAware importAware = new TestImportAware();
71+
assertThatIllegalStateException().isThrownBy(() -> postProcessor.postProcessBeforeInitialization(importAware, "test"))
72+
.withMessageContaining("Failed to read metadata for 'com.example.invalid.DoesNotExist'");
73+
}
74+
75+
76+
static class TestImportAware implements ImportAware {
77+
78+
private AnnotationMetadata importMetadata;
79+
80+
@Override
81+
public void setImportMetadata(AnnotationMetadata importMetadata) {
82+
this.importMetadata = importMetadata;
83+
}
84+
}
85+
86+
static class TestImporting {
87+
88+
}
89+
90+
}

0 commit comments

Comments
 (0)