1919import java .lang .reflect .Modifier ;
2020import java .util .ArrayList ;
2121import java .util .Collection ;
22+ import java .util .Comparator ;
2223import java .util .LinkedHashSet ;
2324import java .util .List ;
2425import java .util .Map ;
2829
2930import org .springframework .aop .framework .AopProxyUtils ;
3031import org .springframework .beans .factory .BeanFactory ;
31- import org .springframework .beans .factory .ListableBeanFactory ;
3232import org .springframework .beans .factory .config .BeanDefinition ;
3333import org .springframework .beans .factory .config .ConfigurableBeanFactory ;
34- import org .springframework .beans .factory .support . BeanDefinitionRegistry ;
34+ import org .springframework .beans .factory .config . ConfigurableListableBeanFactory ;
3535import org .springframework .core .ResolvableType ;
3636import org .springframework .core .annotation .AnnotationAwareOrderComparator ;
3737import org .springframework .lang .Nullable ;
@@ -56,7 +56,6 @@ class EntityCallbackDiscoverer {
5656 private final Map <Class <?>, ResolvableType > entityTypeCache = new ConcurrentReferenceHashMap <>(64 );
5757
5858 @ Nullable private ClassLoader beanClassLoader ;
59- @ Nullable private BeanFactory beanFactory ;
6059
6160 private Object retrievalMutex = this .defaultRetriever ;
6261
@@ -104,12 +103,13 @@ void removeEntityCallback(EntityCallback<?> callback) {
104103 * Return a {@link Collection} of all {@link EntityCallback}s matching the given entity type. Non-matching callbacks
105104 * get excluded early.
106105 *
107- * @param entity the entity to be called back for. Allows for excluding non-matching callbacks early, based on
108- * cached matching information.
106+ * @param entity the entity to be called back for. Allows for excluding non-matching callbacks early, based on cached
107+ * matching information.
109108 * @param callbackType the source callback type.
110109 * @return a {@link Collection} of {@link EntityCallback}s.
111110 * @see EntityCallback
112111 */
112+ @ SuppressWarnings ({ "unchecked" , "rawtypes" })
113113 <T extends S , S > Collection <EntityCallback <S >> getEntityCallbacks (Class <T > entity , ResolvableType callbackType ) {
114114
115115 Class <?> sourceType = entity ;
@@ -121,7 +121,7 @@ <T extends S, S> Collection<EntityCallback<S>> getEntityCallbacks(Class<T> entit
121121 return (Collection ) retriever .getEntityCallbacks ();
122122 }
123123
124- if (this .beanClassLoader == null || ClassUtils .isCacheSafe (entity . getClass () , this .beanClassLoader )
124+ if (this .beanClassLoader == null || ClassUtils .isCacheSafe (entity , this .beanClassLoader )
125125 && (sourceType == null || ClassUtils .isCacheSafe (sourceType , this .beanClassLoader ))) {
126126
127127 // Fully synchronized building and caching of a CallbackRetriever
@@ -163,8 +163,8 @@ ResolvableType resolveDeclaredEntityType(Class<?> callbackType) {
163163 * @param retriever the {@link CallbackRetriever}, if supposed to populate one (for caching purposes)
164164 * @return the pre-filtered list of entity callbacks for the given entity and callback type.
165165 */
166- private Collection <EntityCallback <?>> retrieveEntityCallbacks (ResolvableType entityType ,
167- ResolvableType callbackType , @ Nullable CallbackRetriever retriever ) {
166+ private Collection <EntityCallback <?>> retrieveEntityCallbacks (ResolvableType entityType , ResolvableType callbackType ,
167+ @ Nullable CallbackRetriever retriever ) {
168168
169169 List <EntityCallback <?>> allCallbacks = new ArrayList <>();
170170 Set <EntityCallback <?>> callbacks ;
@@ -198,16 +198,14 @@ private Collection<EntityCallback<?>> retrieveEntityCallbacks(ResolvableType ent
198198 }
199199
200200 /**
201- * Set the {@link BeanFactory} and optionally {@link #setBeanClassLoader(ClassLoader) class loader} if not set.
202- * Pre-loads {@link EntityCallback} beans by scanning the {@link BeanFactory}.
201+ * Set the {@link BeanFactory} and optionally class loader if not set. Pre-loads {@link EntityCallback} beans by
202+ * scanning the {@link BeanFactory}.
203203 *
204204 * @param beanFactory must not be {@literal null}.
205205 * @see org.springframework.beans.factory.BeanFactoryAware#setBeanFactory(BeanFactory)
206206 */
207207 public void setBeanFactory (BeanFactory beanFactory ) {
208208
209- this .beanFactory = beanFactory ;
210-
211209 if (beanFactory instanceof ConfigurableBeanFactory cbf ) {
212210
213211 if (this .beanClassLoader == null ) {
@@ -228,10 +226,8 @@ static Method lookupCallbackMethod(Class<?> callbackType, Class<?> entityType, O
228226
229227 ReflectionUtils .doWithMethods (callbackType , methods ::add , method -> {
230228
231- if (!Modifier .isPublic (method .getModifiers ())
232- || method .getParameterCount () != args .length + 1
233- || method .isBridge ()
234- || ReflectionUtils .isObjectMethod (method )) {
229+ if (!Modifier .isPublic (method .getModifiers ()) || method .getParameterCount () != args .length + 1
230+ || method .isBridge () || ReflectionUtils .isObjectMethod (method )) {
235231 return false ;
236232 }
237233
@@ -242,9 +238,8 @@ static Method lookupCallbackMethod(Class<?> callbackType, Class<?> entityType, O
242238 return methods .iterator ().next ();
243239 }
244240
245- throw new IllegalStateException (
246- "%s does not define a callback method accepting %s and %s additional arguments" .formatted (
247- ClassUtils .getShortName (callbackType ), ClassUtils .getShortName (entityType ), args .length ));
241+ throw new IllegalStateException ("%s does not define a callback method accepting %s and %s additional arguments"
242+ .formatted (ClassUtils .getShortName (callbackType ), ClassUtils .getShortName (entityType ), args .length ));
248243 }
249244
250245 static <T > BiFunction <EntityCallback <T >, T , Object > computeCallbackInvokerFunction (EntityCallback <T > callback ,
@@ -267,10 +262,10 @@ static <T> BiFunction<EntityCallback<T>, T, Object> computeCallbackInvokerFuncti
267262 * Filter a callback early through checking its generically declared entity type before trying to instantiate it.
268263 * <p>
269264 * If this method returns {@literal true} for a given callback as a first pass, the callback instance will get
270- * retrieved and fully evaluated through a {@link #supportsEvent(EntityCallback, ResolvableType, ResolvableType)}
271- * call afterwards.
265+ * retrieved and fully evaluated through a {@link #supportsEvent(EntityCallback, ResolvableType, ResolvableType)} call
266+ * afterwards.
272267 *
273- * @param callback the callback's type as determined by the BeanFactory.
268+ * @param callbackType the callback's type as determined by the BeanFactory.
274269 * @param entityType the entity type to check.
275270 * @return whether the given callback should be included in the candidates for the given callback type.
276271 */
@@ -286,11 +281,9 @@ static boolean supportsEvent(ResolvableType callbackType, ResolvableType entityT
286281 * @param callbackType the source type to check against.
287282 * @return whether the given callback should be included in the candidates for the given callback type.
288283 */
289- static boolean supportsEvent (EntityCallback <?> callback , ResolvableType entityType ,
290- ResolvableType callbackType ) {
284+ static boolean supportsEvent (EntityCallback <?> callback , ResolvableType entityType , ResolvableType callbackType ) {
291285
292- return callback instanceof EntityCallbackAdapter <?> provider
293- ? provider .supports (callbackType , entityType )
286+ return callback instanceof EntityCallbackAdapter <?> provider ? provider .supports (callbackType , entityType )
294287 : callbackType .isInstance (callback ) && supportsEvent (ResolvableType .forInstance (callback ), entityType );
295288 }
296289
@@ -310,13 +303,11 @@ void discoverEntityCallbacks(BeanFactory beanFactory) {
310303
311304 // We need both a ListableBeanFactory and BeanDefinitionRegistry here for advanced inspection.
312305 // If we don't get that, use simple inspection.
313- if (!(beanFactory instanceof ListableBeanFactory && beanFactory instanceof BeanDefinitionRegistry )) {
306+ if (!(beanFactory instanceof ConfigurableListableBeanFactory bf )) {
314307 beanFactory .getBeanProvider (EntityCallback .class ).stream ().forEach (entityCallbacks ::add );
315308 return ;
316309 }
317310
318- var bf = (ListableBeanFactory & BeanDefinitionRegistry ) beanFactory ;
319-
320311 for (var beanName : bf .getBeanNamesForType (EntityCallback .class )) {
321312
322313 EntityCallback <?> bean = EntityCallback .class .cast (bf .getBean (beanName ));
@@ -328,7 +319,7 @@ void discoverEntityCallbacks(BeanFactory beanFactory) {
328319 entityCallbacks .add (bean );
329320 } else {
330321
331- BeanDefinition definition = bf .getBeanDefinition (beanName );
322+ BeanDefinition definition = bf .getMergedBeanDefinition (beanName );
332323 entityCallbacks .add (new EntityCallbackAdapter <>(bean , definition .getResolvableType ()));
333324 }
334325 }
@@ -340,8 +331,8 @@ void discoverEntityCallbacks(BeanFactory beanFactory) {
340331 *
341332 * @author Oliver Drotbohm
342333 */
343- private static record EntityCallbackAdapter <T >(EntityCallback <T > delegate , ResolvableType type )
344- implements EntityCallback <T > {
334+ private record EntityCallbackAdapter <T > (EntityCallback <T > delegate ,
335+ ResolvableType type ) implements EntityCallback <T > {
345336
346337 boolean supports (ResolvableType callbackType , ResolvableType entityType ) {
347338 return callbackType .isInstance (delegate ) && supportsEvent (type , entityType );
@@ -351,15 +342,16 @@ boolean supports(ResolvableType callbackType, ResolvableType entityType) {
351342 /**
352343 * Cache key for {@link EntityCallback}, based on event type and source type.
353344 */
354- private static record CallbackCacheKey (ResolvableType callbackType , @ Nullable Class <?> entityType )
355- implements Comparable <CallbackCacheKey > {
345+ private record CallbackCacheKey (ResolvableType callbackType ,
346+ @ Nullable Class <?> entityType ) implements Comparable <CallbackCacheKey > {
347+
348+ private static final Comparator <CallbackCacheKey > COMPARATOR = Comparators .<CallbackCacheKey > nullsHigh () //
349+ .thenComparing (it -> it .callbackType .toString ()) //
350+ .thenComparing (it -> it .entityType .getName ());
356351
357352 @ Override
358353 public int compareTo (CallbackCacheKey other ) {
359-
360- return Comparators .<CallbackCacheKey > nullsHigh ()
361- .thenComparing (it -> callbackType .toString ())
362- .thenComparing (it -> entityType .getName ()).compare (this , other );
354+ return COMPARATOR .compare (this , other );
363355 }
364356 }
365357
0 commit comments