1515 */
1616package org .springframework .data .aot ;
1717
18+ import java .lang .reflect .Executable ;
19+ import java .lang .reflect .Method ;
1820import java .util .List ;
1921import java .util .function .BiConsumer ;
2022
23+ import javax .lang .model .element .Modifier ;
24+
25+ import org .springframework .aot .generate .AccessVisibility ;
26+ import org .springframework .aot .generate .GeneratedMethod ;
2127import org .springframework .aot .generate .GenerationContext ;
2228import org .springframework .beans .factory .aot .BeanRegistrationAotContribution ;
2329import org .springframework .beans .factory .aot .BeanRegistrationCode ;
30+ import org .springframework .beans .factory .aot .BeanRegistrationCodeFragments ;
31+ import org .springframework .beans .factory .support .InstanceSupplier ;
32+ import org .springframework .beans .factory .support .RegisteredBean ;
2433import org .springframework .core .ResolvableType ;
2534import org .springframework .data .domain .ManagedTypes ;
35+ import org .springframework .data .util .Lazy ;
36+ import org .springframework .javapoet .CodeBlock ;
37+ import org .springframework .javapoet .MethodSpec .Builder ;
38+ import org .springframework .javapoet .ParameterizedTypeName ;
2639import org .springframework .lang .Nullable ;
40+ import org .springframework .util .ClassUtils ;
41+ import org .springframework .util .ObjectUtils ;
42+ import org .springframework .util .ReflectionUtils ;
2743
2844/**
2945 * {@link BeanRegistrationAotContribution} used to contribute a {@link ManagedTypes} registration.
46+ * <p>
47+ * Will try to resolve bean definition arguments if possible and fall back to resolving the bean from the context if
48+ * that is not possible. To avoid duplicate invocations of potential scan operations hidden by the {@link ManagedTypes}
49+ * instance the {@link BeanRegistrationAotContribution} will write custom instantiation code via
50+ * {@link BeanRegistrationAotContribution#customizeBeanRegistrationCodeFragments(GenerationContext, BeanRegistrationCodeFragments)}.
51+ * The generated code resolves potential factory methods accepting either a {@link ManagedTypes} instance, or a
52+ * {@link List} of either {@link Class} or {@link String} (classname) values.
53+ *
54+ * <pre>
55+ * <code>
56+ * public static InstanceSupplier<ManagedTypes> instance() {
57+ * return (registeredBean) -> {
58+ * var types = List.of("com.example.A", "com.example.B");
59+ * return ManagedTypes.ofStream(types.stream().map(it -> ClassUtils.forName(it, registeredBean.getBeanFactory().getBeanClassLoader())));
60+ * }
61+ * }
62+ * </code>
63+ * </pre>
3064 *
3165 * @author John Blum
66+ * @author Christoph Strobl
3267 * @see org.springframework.beans.factory.aot.BeanRegistrationAotContribution
3368 * @since 3.0.0
3469 */
35- public class ManagedTypesRegistrationAotContribution implements BeanRegistrationAotContribution {
70+ public class ManagedTypesRegistrationAotContribution implements RegisteredBeanAotContribution {
3671
3772 private final AotContext aotContext ;
3873 private final ManagedTypes managedTypes ;
3974 private final BiConsumer <ResolvableType , GenerationContext > contributionAction ;
75+ private final RegisteredBean source ;
4076
4177 public ManagedTypesRegistrationAotContribution (AotContext aotContext , @ Nullable ManagedTypes managedTypes ,
42- BiConsumer <ResolvableType , GenerationContext > contributionAction ) {
78+ RegisteredBean registeredBean , BiConsumer <ResolvableType , GenerationContext > contributionAction ) {
4379
4480 this .aotContext = aotContext ;
4581 this .managedTypes = managedTypes ;
4682 this .contributionAction = contributionAction ;
83+ this .source = registeredBean ;
4784 }
4885
4986 protected AotContext getAotContext () {
@@ -63,4 +100,129 @@ public void applyTo(GenerationContext generationContext, BeanRegistrationCode be
63100 TypeCollector .inspect (types ).forEach (type -> contributionAction .accept (type , generationContext ));
64101 }
65102 }
103+
104+ @ Override
105+ public BeanRegistrationCodeFragments customizeBeanRegistrationCodeFragments (GenerationContext generationContext ,
106+ BeanRegistrationCodeFragments codeFragments ) {
107+
108+ if (managedTypes == null ) {
109+ return codeFragments ;
110+ }
111+
112+ ManagedTypesInstanceCodeFragment fragment = new ManagedTypesInstanceCodeFragment (getManagedTypes (), source ,
113+ codeFragments );
114+ return fragment .canGenerateCode () ? fragment : codeFragments ;
115+ }
116+
117+ @ Override
118+ public RegisteredBean getSource () {
119+ return source ;
120+ }
121+
122+ static class ManagedTypesInstanceCodeFragment extends BeanRegistrationCodeFragments {
123+
124+ private ManagedTypes sourceTypes ;
125+ private RegisteredBean source ;
126+ private Lazy <Method > instanceMethod = Lazy .of (this ::findInstanceFactory );
127+
128+ protected ManagedTypesInstanceCodeFragment (ManagedTypes managedTypes , RegisteredBean source ,
129+ BeanRegistrationCodeFragments codeFragments ) {
130+
131+ super (codeFragments );
132+
133+ this .sourceTypes = managedTypes ;
134+ this .source = source ;
135+ }
136+
137+ /**
138+ * @return {@literal true} if the instance method code can be generated. {@literal false} otherwise.
139+ */
140+ boolean canGenerateCode () {
141+
142+ if (ObjectUtils .nullSafeEquals (source .getBeanClass (), ManagedTypes .class )) {
143+ return true ;
144+ }
145+ return instanceMethod .getNullable () != null ;
146+ }
147+
148+ @ Override
149+ public CodeBlock generateInstanceSupplierCode (GenerationContext generationContext ,
150+ BeanRegistrationCode beanRegistrationCode , Executable constructorOrFactoryMethod ,
151+ boolean allowDirectSupplierShortcut ) {
152+
153+ GeneratedMethod generatedMethod = beanRegistrationCode .getMethods ().add ("Instance" ,
154+ this ::generateInstanceFactory );
155+
156+ return CodeBlock .of ("$T.$L()" , beanRegistrationCode .getClassName (), generatedMethod .getName ());
157+ }
158+
159+ private CodeBlock toCodeBlock (List <Class <?>> values , boolean allPublic ) {
160+
161+ if (allPublic ) {
162+ return CodeBlock .join (values .stream ().map (value -> CodeBlock .of ("$T.class" , value )).toList (), ", " );
163+ }
164+ return CodeBlock .join (values .stream ().map (value -> CodeBlock .of ("$S" , value .getName ())).toList (), ", " );
165+ }
166+
167+ private Method findInstanceFactory () {
168+
169+ for (Method beanMethod : ReflectionUtils .getDeclaredMethods (source .getBeanClass ())) {
170+
171+ if (beanMethod .getParameterCount () == 1 && java .lang .reflect .Modifier .isPublic (beanMethod .getModifiers ())
172+ && java .lang .reflect .Modifier .isStatic (beanMethod .getModifiers ())) {
173+ ResolvableType parameterType = ResolvableType .forMethodParameter (beanMethod , 0 , source .getBeanClass ());
174+ if (parameterType .isAssignableFrom (ResolvableType .forType (List .class ))
175+ || parameterType .isAssignableFrom (ResolvableType .forType (ManagedTypes .class ))) {
176+ return beanMethod ;
177+ }
178+ }
179+ }
180+ return null ;
181+ }
182+
183+ void generateInstanceFactory (Builder method ) {
184+
185+ List <Class <?>> sourceTypes = this .sourceTypes .toList ();
186+ boolean allSourceTypesVisible = sourceTypes .stream ()
187+ .allMatch (it -> AccessVisibility .PUBLIC .equals (AccessVisibility .forClass (it )));
188+
189+ ParameterizedTypeName targetTypeName = ParameterizedTypeName .get (InstanceSupplier .class , source .getBeanClass ());
190+
191+ method .addModifiers (Modifier .PRIVATE , Modifier .STATIC );
192+ method .returns (targetTypeName );
193+
194+ CodeBlock .Builder builder = CodeBlock .builder ().add ("return " ).beginControlFlow ("(registeredBean -> " );
195+
196+ builder .addStatement ("var types = $T.of($L)" , List .class , toCodeBlock (sourceTypes , allSourceTypesVisible ));
197+
198+ if (allSourceTypesVisible ) {
199+ builder .addStatement ("var managedTypes = $T.fromIterable($L)" , ManagedTypes .class , "types" );
200+ } else {
201+ builder .add (CodeBlock .builder ()
202+ .beginControlFlow ("var managedTypes = $T.fromStream(types.stream().map(it ->" , ManagedTypes .class )
203+ .beginControlFlow ("try" )
204+ .addStatement ("return $T.forName(it, registeredBean.getBeanFactory().getBeanClassLoader())" ,
205+ ClassUtils .class )
206+ .nextControlFlow ("catch ($T e)" , ClassNotFoundException .class )
207+ .addStatement ("throw new $T($S, e)" , IllegalArgumentException .class , "Cannot to load type" ).endControlFlow ()
208+ .endControlFlow ("))" ).build ());
209+ }
210+ if (ObjectUtils .nullSafeEquals (source .getBeanClass (), ManagedTypes .class )) {
211+ builder .add ("return managedTypes" );
212+ } else {
213+ Method instanceFactoryMethod = instanceMethod .get ();
214+ if (ResolvableType .forMethodParameter (instanceFactoryMethod , 0 )
215+ .isAssignableFrom (ResolvableType .forType (ManagedTypes .class ))) {
216+ builder .addStatement ("return $T.$L($L)" , instanceFactoryMethod .getDeclaringClass (),
217+ instanceFactoryMethod .getName (), "managedTypes" );
218+
219+ } else {
220+ builder .addStatement ("return $T.$L($L.toList())" , instanceFactoryMethod .getDeclaringClass (),
221+ instanceFactoryMethod .getName (), "managedTypes" );
222+ }
223+ }
224+ builder .endControlFlow (")" );
225+ method .addCode (builder .build ());
226+ }
227+ }
66228}
0 commit comments