3333import io .zonky .test .db .provider .impl .ZonkyPostgresDatabaseProvider ;
3434import org .apache .commons .lang3 .StringUtils ;
3535import org .flywaydb .core .Flyway ;
36- import org .flywaydb .test .annotation .FlywayTest ;
3736import org .slf4j .Logger ;
3837import org .slf4j .LoggerFactory ;
3938import org .springframework .beans .BeansException ;
6059import org .springframework .util .ObjectUtils ;
6160
6261import javax .sql .DataSource ;
63- import java .lang .reflect .AnnotatedElement ;
6462import java .util .LinkedHashSet ;
6563import java .util .List ;
6664import java .util .Set ;
@@ -116,17 +114,12 @@ public PreloadableEmbeddedPostgresContextCustomizer(Set<AutoConfigureEmbeddedDat
116114 public void customizeContext (ConfigurableApplicationContext context , MergedContextConfiguration mergedConfig ) {
117115 context .addBeanFactoryPostProcessor (new EnvironmentPostProcessor (context .getEnvironment ()));
118116
119- Class <?> testClass = mergedConfig .getTestClass ();
120- FlywayTest [] flywayAnnotations = findFlywayTestAnnotations (testClass );
121-
122117 BeanDefinitionRegistry registry = getBeanDefinitionRegistry (context );
123118 RootBeanDefinition registrarDefinition = new RootBeanDefinition ();
124119
125120 registrarDefinition .setBeanClass (PreloadableEmbeddedPostgresRegistrar .class );
126121 registrarDefinition .getConstructorArgumentValues ()
127122 .addIndexedArgumentValue (0 , databaseAnnotations );
128- registrarDefinition .getConstructorArgumentValues ()
129- .addIndexedArgumentValue (1 , flywayAnnotations );
130123
131124 registry .registerBeanDefinition ("preloadableEmbeddedPostgresRegistrar" , registrarDefinition );
132125 }
@@ -172,13 +165,11 @@ public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory)
172165 protected static class PreloadableEmbeddedPostgresRegistrar implements BeanDefinitionRegistryPostProcessor , EnvironmentAware {
173166
174167 private final Set <AutoConfigureEmbeddedDatabase > databaseAnnotations ;
175- private final FlywayTest [] flywayAnnotations ;
176168
177169 private Environment environment ;
178170
179- public PreloadableEmbeddedPostgresRegistrar (Set <AutoConfigureEmbeddedDatabase > databaseAnnotations , FlywayTest [] flywayAnnotations ) {
171+ public PreloadableEmbeddedPostgresRegistrar (Set <AutoConfigureEmbeddedDatabase > databaseAnnotations ) {
180172 this .databaseAnnotations = databaseAnnotations ;
181- this .flywayAnnotations = flywayAnnotations ;
182173 }
183174
184175 @ Override
@@ -211,7 +202,7 @@ public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) t
211202 DatabaseDescriptor databaseDescriptor = resolveDatabaseDescriptor (environment , databaseAnnotation );
212203
213204 BeanDefinitionHolder dataSourceInfo = getDataSourceBeanDefinition (beanFactory , databaseAnnotation );
214- BeanDefinitionHolder flywayInfo = getFlywayBeanDefinition (beanFactory , flywayAnnotations );
205+ BeanDefinitionHolder flywayInfo = getFlywayBeanDefinition (beanFactory );
215206
216207 RootBeanDefinition dataSourceDefinition = new RootBeanDefinition ();
217208 dataSourceDefinition .setPrimary (dataSourceInfo .getBeanDefinition ().isPrimary ());
@@ -220,7 +211,7 @@ public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) t
220211 dataSourceDefinition .setBeanClass (EmptyEmbeddedPostgresDataSourceFactoryBean .class );
221212 dataSourceDefinition .getConstructorArgumentValues ().addIndexedArgumentValue (0 , databaseDescriptor );
222213 } else {
223- BeanDefinitionHolder contextInfo = getDataSourceContextBeanDefinition (beanFactory , flywayAnnotations );
214+ BeanDefinitionHolder contextInfo = getDataSourceContextBeanDefinition (beanFactory , flywayInfo . getBeanName () );
224215
225216 if (contextInfo == null ) {
226217 RootBeanDefinition dataSourceContextDefinition = new RootBeanDefinition ();
@@ -309,20 +300,8 @@ protected static BeanDefinitionHolder getDataSourceBeanDefinition(ConfigurableLi
309300 throw new IllegalStateException ("No primary DataSource found, embedded version will not be used" );
310301 }
311302
312- protected static BeanDefinitionHolder getDataSourceContextBeanDefinition (ConfigurableListableBeanFactory beanFactory , FlywayTest [] annotations ) {
313- String flywayBeanName = resolveFlywayBeanName (annotations );
314-
315- if (StringUtils .isNotBlank (flywayBeanName )) {
316- String contextBeanName = flywayBeanName + "DataSourceContext" ;
317- if (beanFactory .containsBean (contextBeanName )) {
318- BeanDefinition beanDefinition = beanFactory .getBeanDefinition (contextBeanName );
319- return new BeanDefinitionHolder (beanDefinition , contextBeanName );
320- } else {
321- return null ;
322- }
323- }
324-
325- String [] beanNames = beanFactory .getBeanNamesForType (FlywayDataSourceContext .class );
303+ protected static BeanDefinitionHolder getDataSourceContextBeanDefinition (ConfigurableListableBeanFactory beanFactory , String flywayName ) {
304+ String [] beanNames = beanFactory .getBeanNamesForType (FlywayDataSourceContext .class , true , false );
326305
327306 if (ObjectUtils .isEmpty (beanNames )) {
328307 return null ;
@@ -334,76 +313,26 @@ protected static BeanDefinitionHolder getDataSourceContextBeanDefinition(Configu
334313 return new BeanDefinitionHolder (beanDefinition , beanName );
335314 }
336315
337- for (String beanName : beanNames ) {
338- BeanDefinition beanDefinition = beanFactory .getBeanDefinition (beanName );
339- if (beanDefinition .isPrimary ()) {
340- return new BeanDefinitionHolder (beanDefinition , beanName );
341- }
316+ if (beanFactory .containsBean (flywayName + "DataSourceContext" )) {
317+ BeanDefinition beanDefinition = beanFactory .getBeanDefinition (flywayName + "DataSourceContext" );
318+ return new BeanDefinitionHolder (beanDefinition , flywayName + "DataSourceContext" );
342319 }
343320
344321 return null ;
345322 }
346323
347- protected static BeanDefinitionHolder getFlywayBeanDefinition (ConfigurableListableBeanFactory beanFactory , FlywayTest [] annotations ) {
348- if (annotations .length > 1 ) {
349- return null ; // optimized loading is not supported yet when using multiple flyway test annotations
350- }
351-
352- String flywayBeanName = resolveFlywayBeanName (annotations );
353-
354- if (StringUtils .isNotBlank (flywayBeanName )) {
355- BeanDefinition beanDefinition = beanFactory .getBeanDefinition (flywayBeanName );
356- return new BeanDefinitionHolder (beanDefinition , flywayBeanName );
357- }
358-
359- String [] beanNames = beanFactory .getBeanNamesForType (Flyway .class );
360-
361- if (ObjectUtils .isEmpty (beanNames )) {
362- return null ;
363- }
324+ protected static BeanDefinitionHolder getFlywayBeanDefinition (ConfigurableListableBeanFactory beanFactory ) {
325+ String [] beanNames = beanFactory .getBeanNamesForType (Flyway .class , true , false );
364326
365327 if (beanNames .length == 1 ) {
366328 String beanName = beanNames [0 ];
367329 BeanDefinition beanDefinition = beanFactory .getBeanDefinition (beanName );
368330 return new BeanDefinitionHolder (beanDefinition , beanName );
369331 }
370332
371- for (String beanName : beanNames ) {
372- BeanDefinition beanDefinition = beanFactory .getBeanDefinition (beanName );
373- if (beanDefinition .isPrimary ()) {
374- return new BeanDefinitionHolder (beanDefinition , beanName );
375- }
376- }
377-
378333 return null ;
379334 }
380335
381- protected static FlywayTest [] findFlywayTestAnnotations (AnnotatedElement element ) {
382- if (repeatableAnnotationPresent ) {
383- org .flywaydb .test .annotation .FlywayTests flywayContainerAnnotation =
384- AnnotatedElementUtils .findMergedAnnotation (element , org .flywaydb .test .annotation .FlywayTests .class );
385- if (flywayContainerAnnotation != null ) {
386- return flywayContainerAnnotation .value ();
387- }
388- }
389-
390- FlywayTest flywayAnnotation = AnnotatedElementUtils .findMergedAnnotation (element , FlywayTest .class );
391- if (flywayAnnotation != null ) {
392- return new FlywayTest [] { flywayAnnotation };
393- }
394-
395- return new FlywayTest [0 ];
396- }
397-
398- protected static String resolveFlywayBeanName (FlywayTest [] annotations ) {
399- FlywayTest annotation = annotations .length == 1 ? annotations [0 ] : null ;
400- if (annotation != null && flywayNameAttributePresent ) {
401- return annotation .flywayName ();
402- } else {
403- return null ;
404- }
405- }
406-
407336 protected static <T > Predicate <T > distinctByKey (Function <? super T , ?> keyExtractor ) {
408337 Set <Object > seen = ConcurrentHashMap .newKeySet ();
409338 return t -> seen .add (keyExtractor .apply (t ));
0 commit comments