1616package org .springframework .data .aot ;
1717
1818import java .util .Collections ;
19- import java .util .function .Function ;
2019import java .util .function .Supplier ;
2120
2221import org .apache .commons .logging .Log ;
4140 *
4241 * @author Christoph Strobl
4342 * @author John Blum
44- * @see org.springframework.beans.factory.config.ConfigurableListableBeanFactory
45- * @see org.springframework.beans.factory.aot.BeanFactoryInitializationAotContribution
46- * @see org.springframework.beans.factory.aot.BeanFactoryInitializationAotProcessor
4743 * @since 3.0
4844 */
49- public class SpringDataBeanFactoryInitializationAotProcessor implements BeanFactoryInitializationAotProcessor {
45+ public class ManagedTypesBeanFactoryInitializationAotProcessor implements BeanFactoryInitializationAotProcessor {
5046
5147 private static final Log logger = LogFactory .getLog (BeanFactoryInitializationAotProcessor .class );
5248
53- private static final Function <Object , Object > arrayToListFunction = target ->
54- ObjectUtils .isArray (target ) ? CollectionUtils .arrayToList (target ) : target ;
55-
56- private static final Function <Object , Object > asSingletonSetFunction = target ->
57- !(target instanceof Iterable <?>) ? Collections .singleton (target ) : target ;
58-
59- private static final Function <Object , Object > constructorArgumentFunction =
60- arrayToListFunction .andThen (asSingletonSetFunction );
61-
6249 @ Nullable
6350 @ Override
6451 public BeanFactoryInitializationAotContribution processAheadOfTime (ConfigurableListableBeanFactory beanFactory ) {
@@ -70,42 +57,61 @@ public BeanFactoryInitializationAotContribution processAheadOfTime(ConfigurableL
7057 private void processManagedTypes (ConfigurableListableBeanFactory beanFactory ) {
7158
7259 if (beanFactory instanceof BeanDefinitionRegistry registry ) {
60+
7361 for (String beanName : beanFactory .getBeanNamesForType (ManagedTypes .class )) {
62+ postProcessManagedTypes (beanFactory , registry , beanName );
63+ }
64+ }
65+ }
7466
75- BeanDefinition beanDefinition = beanFactory .getBeanDefinition (beanName );
67+ private void postProcessManagedTypes (ConfigurableListableBeanFactory beanFactory , BeanDefinitionRegistry registry ,
68+ String beanName ) {
7669
77- if ( hasConstructorArguments ( beanDefinition )) {
70+ BeanDefinition beanDefinition = beanFactory . getBeanDefinition ( beanName );
7871
79- ValueHolder argumentValue = beanDefinition .getConstructorArgumentValues ()
80- .getArgumentValue (0 , null , null , null );
72+ if (hasConstructorArguments (beanDefinition )) {
8173
82- if ( argumentValue . getValue () instanceof Supplier supplier ) {
74+ ValueHolder argumentValue = beanDefinition . getConstructorArgumentValues (). getArgumentValue ( 0 , null , null , null );
8375
84- if (logger .isDebugEnabled ()) {
85- logger .info (String .format ("Replacing ManagedType bean definition %s." , beanName ));
86- }
76+ if (argumentValue .getValue ()instanceof Supplier supplier ) {
8777
88- Object value = constructorArgumentFunction .apply (supplier .get ());
78+ if (logger .isDebugEnabled ()) {
79+ logger .info (String .format ("Replacing ManagedType bean definition %s." , beanName ));
80+ }
8981
90- BeanDefinition beanDefinitionReplacement = newManagedTypeBeanDefinition ( value );
82+ Object value = potentiallyWrapToIterable ( supplier . get () );
9183
92- registry .removeBeanDefinition (beanName );
93- registry .registerBeanDefinition (beanName , beanDefinitionReplacement );
94- }
95- }
84+ BeanDefinition beanDefinitionReplacement = newManagedTypeBeanDefinition (beanDefinition .getBeanClassName (),
85+ value );
86+
87+ registry .removeBeanDefinition (beanName );
88+ registry .registerBeanDefinition (beanName , beanDefinitionReplacement );
9689 }
9790 }
9891 }
9992
93+ private static Object potentiallyWrapToIterable (Object value ) {
94+
95+ if (ObjectUtils .isArray (value )) {
96+ return CollectionUtils .arrayToList (value );
97+ }
98+
99+ if (value instanceof Iterable <?>) {
100+ return value ;
101+ }
102+
103+ return Collections .singleton (value );
104+ }
105+
100106 private boolean hasConstructorArguments (BeanDefinition beanDefinition ) {
101107 return !beanDefinition .getConstructorArgumentValues ().isEmpty ();
102108 }
103109
104- private BeanDefinition newManagedTypeBeanDefinition (Object constructorArgumentValue ) {
110+ private BeanDefinition newManagedTypeBeanDefinition (String managedTypesClassName , Object constructorArgumentValue ) {
105111
106- return BeanDefinitionBuilder .rootBeanDefinition (ManagedTypes . class )
107- .setFactoryMethod ("fromIterable" )
108- .addConstructorArgValue (constructorArgumentValue )
109- .getBeanDefinition ();
112+ return BeanDefinitionBuilder .rootBeanDefinition (managedTypesClassName ) //
113+ .setFactoryMethod ("fromIterable" ) //
114+ .addConstructorArgValue (constructorArgumentValue ) //
115+ .getBeanDefinition ();
110116 }
111117}
0 commit comments