15
15
*/
16
16
package org .springframework .data .aot ;
17
17
18
+ import java .lang .reflect .Executable ;
19
+ import java .lang .reflect .Method ;
18
20
import java .util .List ;
19
21
import java .util .function .BiConsumer ;
20
22
23
+ import javax .lang .model .element .Modifier ;
24
+
25
+ import org .springframework .aot .generate .AccessVisibility ;
26
+ import org .springframework .aot .generate .GeneratedMethod ;
21
27
import org .springframework .aot .generate .GenerationContext ;
22
28
import org .springframework .beans .factory .aot .BeanRegistrationAotContribution ;
23
29
import org .springframework .beans .factory .aot .BeanRegistrationCode ;
30
+ import org .springframework .beans .factory .aot .BeanRegistrationCodeFragments ;
31
+ import org .springframework .beans .factory .support .InstanceSupplier ;
32
+ import org .springframework .beans .factory .support .RegisteredBean ;
24
33
import org .springframework .core .ResolvableType ;
25
34
import org .springframework .data .domain .ManagedTypes ;
35
+ import org .springframework .data .util .Lazy ;
36
+ import org .springframework .javapoet .CodeBlock ;
37
+ import org .springframework .javapoet .MethodSpec .Builder ;
38
+ import org .springframework .javapoet .ParameterizedTypeName ;
26
39
import org .springframework .lang .Nullable ;
40
+ import org .springframework .util .ClassUtils ;
41
+ import org .springframework .util .ObjectUtils ;
42
+ import org .springframework .util .ReflectionUtils ;
27
43
28
44
/**
29
45
* {@link BeanRegistrationAotContribution} used to contribute a {@link ManagedTypes} registration.
46
+ * <p>
47
+ * Will try to resolve bean definition arguments if possible and fall back to resolving the bean from the context if
48
+ * that is not possible. To avoid duplicate invocations of potential scan operations hidden by the {@link ManagedTypes}
49
+ * instance the {@link BeanRegistrationAotContribution} will write custom instantiation code via
50
+ * {@link BeanRegistrationAotContribution#customizeBeanRegistrationCodeFragments(GenerationContext, BeanRegistrationCodeFragments)}.
51
+ * The generated code resolves potential factory methods accepting either a {@link ManagedTypes} instance, or a
52
+ * {@link List} of either {@link Class} or {@link String} (classname) values.
53
+ *
54
+ * <pre>
55
+ * <code>
56
+ * public static InstanceSupplier<ManagedTypes> instance() {
57
+ * return (registeredBean) -> {
58
+ * var types = List.of("com.example.A", "com.example.B");
59
+ * return ManagedTypes.ofStream(types.stream().map(it -> ClassUtils.forName(it, registeredBean.getBeanFactory().getBeanClassLoader())));
60
+ * }
61
+ * }
62
+ * </code>
63
+ * </pre>
30
64
*
31
65
* @author John Blum
66
+ * @author Christoph Strobl
32
67
* @see org.springframework.beans.factory.aot.BeanRegistrationAotContribution
33
68
* @since 3.0.0
34
69
*/
35
- public class ManagedTypesRegistrationAotContribution implements BeanRegistrationAotContribution {
70
+ public class ManagedTypesRegistrationAotContribution implements RegisteredBeanAotContribution {
36
71
37
72
private final AotContext aotContext ;
38
73
private final ManagedTypes managedTypes ;
39
74
private final BiConsumer <ResolvableType , GenerationContext > contributionAction ;
75
+ private final RegisteredBean source ;
40
76
41
77
public ManagedTypesRegistrationAotContribution (AotContext aotContext , @ Nullable ManagedTypes managedTypes ,
42
- BiConsumer <ResolvableType , GenerationContext > contributionAction ) {
78
+ RegisteredBean registeredBean , BiConsumer <ResolvableType , GenerationContext > contributionAction ) {
43
79
44
80
this .aotContext = aotContext ;
45
81
this .managedTypes = managedTypes ;
46
82
this .contributionAction = contributionAction ;
83
+ this .source = registeredBean ;
47
84
}
48
85
49
86
protected AotContext getAotContext () {
@@ -63,4 +100,129 @@ public void applyTo(GenerationContext generationContext, BeanRegistrationCode be
63
100
TypeCollector .inspect (types ).forEach (type -> contributionAction .accept (type , generationContext ));
64
101
}
65
102
}
103
+
104
+ @ Override
105
+ public BeanRegistrationCodeFragments customizeBeanRegistrationCodeFragments (GenerationContext generationContext ,
106
+ BeanRegistrationCodeFragments codeFragments ) {
107
+
108
+ if (managedTypes == null ) {
109
+ return codeFragments ;
110
+ }
111
+
112
+ ManagedTypesInstanceCodeFragment fragment = new ManagedTypesInstanceCodeFragment (getManagedTypes (), source ,
113
+ codeFragments );
114
+ return fragment .canGenerateCode () ? fragment : codeFragments ;
115
+ }
116
+
117
+ @ Override
118
+ public RegisteredBean getSource () {
119
+ return source ;
120
+ }
121
+
122
+ static class ManagedTypesInstanceCodeFragment extends BeanRegistrationCodeFragments {
123
+
124
+ private ManagedTypes sourceTypes ;
125
+ private RegisteredBean source ;
126
+ private Lazy <Method > instanceMethod = Lazy .of (this ::findInstanceFactory );
127
+
128
+ protected ManagedTypesInstanceCodeFragment (ManagedTypes managedTypes , RegisteredBean source ,
129
+ BeanRegistrationCodeFragments codeFragments ) {
130
+
131
+ super (codeFragments );
132
+
133
+ this .sourceTypes = managedTypes ;
134
+ this .source = source ;
135
+ }
136
+
137
+ /**
138
+ * @return {@literal true} if the instance method code can be generated. {@literal false} otherwise.
139
+ */
140
+ boolean canGenerateCode () {
141
+
142
+ if (ObjectUtils .nullSafeEquals (source .getBeanClass (), ManagedTypes .class )) {
143
+ return true ;
144
+ }
145
+ return instanceMethod .getNullable () != null ;
146
+ }
147
+
148
+ @ Override
149
+ public CodeBlock generateInstanceSupplierCode (GenerationContext generationContext ,
150
+ BeanRegistrationCode beanRegistrationCode , Executable constructorOrFactoryMethod ,
151
+ boolean allowDirectSupplierShortcut ) {
152
+
153
+ GeneratedMethod generatedMethod = beanRegistrationCode .getMethods ().add ("Instance" ,
154
+ this ::generateInstanceFactory );
155
+
156
+ return CodeBlock .of ("$T.$L()" , beanRegistrationCode .getClassName (), generatedMethod .getName ());
157
+ }
158
+
159
+ private CodeBlock toCodeBlock (List <Class <?>> values , boolean allPublic ) {
160
+
161
+ if (allPublic ) {
162
+ return CodeBlock .join (values .stream ().map (value -> CodeBlock .of ("$T.class" , value )).toList (), ", " );
163
+ }
164
+ return CodeBlock .join (values .stream ().map (value -> CodeBlock .of ("$S" , value .getName ())).toList (), ", " );
165
+ }
166
+
167
+ private Method findInstanceFactory () {
168
+
169
+ for (Method beanMethod : ReflectionUtils .getDeclaredMethods (source .getBeanClass ())) {
170
+
171
+ if (beanMethod .getParameterCount () == 1 && java .lang .reflect .Modifier .isPublic (beanMethod .getModifiers ())
172
+ && java .lang .reflect .Modifier .isStatic (beanMethod .getModifiers ())) {
173
+ ResolvableType parameterType = ResolvableType .forMethodParameter (beanMethod , 0 , source .getBeanClass ());
174
+ if (parameterType .isAssignableFrom (ResolvableType .forType (List .class ))
175
+ || parameterType .isAssignableFrom (ResolvableType .forType (ManagedTypes .class ))) {
176
+ return beanMethod ;
177
+ }
178
+ }
179
+ }
180
+ return null ;
181
+ }
182
+
183
+ void generateInstanceFactory (Builder method ) {
184
+
185
+ List <Class <?>> sourceTypes = this .sourceTypes .toList ();
186
+ boolean allSourceTypesVisible = sourceTypes .stream ()
187
+ .allMatch (it -> AccessVisibility .PUBLIC .equals (AccessVisibility .forClass (it )));
188
+
189
+ ParameterizedTypeName targetTypeName = ParameterizedTypeName .get (InstanceSupplier .class , source .getBeanClass ());
190
+
191
+ method .addModifiers (Modifier .PRIVATE , Modifier .STATIC );
192
+ method .returns (targetTypeName );
193
+
194
+ CodeBlock .Builder builder = CodeBlock .builder ().add ("return " ).beginControlFlow ("(registeredBean -> " );
195
+
196
+ builder .addStatement ("var types = $T.of($L)" , List .class , toCodeBlock (sourceTypes , allSourceTypesVisible ));
197
+
198
+ if (allSourceTypesVisible ) {
199
+ builder .addStatement ("var managedTypes = $T.fromIterable($L)" , ManagedTypes .class , "types" );
200
+ } else {
201
+ builder .add (CodeBlock .builder ()
202
+ .beginControlFlow ("var managedTypes = $T.fromStream(types.stream().map(it ->" , ManagedTypes .class )
203
+ .beginControlFlow ("try" )
204
+ .addStatement ("return $T.forName(it, registeredBean.getBeanFactory().getBeanClassLoader())" ,
205
+ ClassUtils .class )
206
+ .nextControlFlow ("catch ($T e)" , ClassNotFoundException .class )
207
+ .addStatement ("throw new $T($S, e)" , IllegalArgumentException .class , "Cannot to load type" ).endControlFlow ()
208
+ .endControlFlow ("))" ).build ());
209
+ }
210
+ if (ObjectUtils .nullSafeEquals (source .getBeanClass (), ManagedTypes .class )) {
211
+ builder .add ("return managedTypes" );
212
+ } else {
213
+ Method instanceFactoryMethod = instanceMethod .get ();
214
+ if (ResolvableType .forMethodParameter (instanceFactoryMethod , 0 )
215
+ .isAssignableFrom (ResolvableType .forType (ManagedTypes .class ))) {
216
+ builder .addStatement ("return $T.$L($L)" , instanceFactoryMethod .getDeclaringClass (),
217
+ instanceFactoryMethod .getName (), "managedTypes" );
218
+
219
+ } else {
220
+ builder .addStatement ("return $T.$L($L.toList())" , instanceFactoryMethod .getDeclaringClass (),
221
+ instanceFactoryMethod .getName (), "managedTypes" );
222
+ }
223
+ }
224
+ builder .endControlFlow (")" );
225
+ method .addCode (builder .build ());
226
+ }
227
+ }
66
228
}
0 commit comments