Skip to content

Commit c571a35

Browse files
committed
#43 fix database consistency when multiple flyway beans are used
1 parent 575de8b commit c571a35

File tree

5 files changed

+160
-107
lines changed

5 files changed

+160
-107
lines changed

embedded-database-spring-test/src/main/java/io/zonky/test/db/postgres/EmbeddedPostgresContextCustomizerFactory.java

Lines changed: 10 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
import io.zonky.test.db.provider.impl.ZonkyPostgresDatabaseProvider;
3434
import org.apache.commons.lang3.StringUtils;
3535
import org.flywaydb.core.Flyway;
36-
import org.flywaydb.test.annotation.FlywayTest;
3736
import org.slf4j.Logger;
3837
import org.slf4j.LoggerFactory;
3938
import org.springframework.beans.BeansException;
@@ -60,7 +59,6 @@
6059
import org.springframework.util.ObjectUtils;
6160

6261
import javax.sql.DataSource;
63-
import java.lang.reflect.AnnotatedElement;
6462
import java.util.LinkedHashSet;
6563
import java.util.List;
6664
import java.util.Set;
@@ -116,17 +114,12 @@ public PreloadableEmbeddedPostgresContextCustomizer(Set<AutoConfigureEmbeddedDat
116114
public void customizeContext(ConfigurableApplicationContext context, MergedContextConfiguration mergedConfig) {
117115
context.addBeanFactoryPostProcessor(new EnvironmentPostProcessor(context.getEnvironment()));
118116

119-
Class<?> testClass = mergedConfig.getTestClass();
120-
FlywayTest[] flywayAnnotations = findFlywayTestAnnotations(testClass);
121-
122117
BeanDefinitionRegistry registry = getBeanDefinitionRegistry(context);
123118
RootBeanDefinition registrarDefinition = new RootBeanDefinition();
124119

125120
registrarDefinition.setBeanClass(PreloadableEmbeddedPostgresRegistrar.class);
126121
registrarDefinition.getConstructorArgumentValues()
127122
.addIndexedArgumentValue(0, databaseAnnotations);
128-
registrarDefinition.getConstructorArgumentValues()
129-
.addIndexedArgumentValue(1, flywayAnnotations);
130123

131124
registry.registerBeanDefinition("preloadableEmbeddedPostgresRegistrar", registrarDefinition);
132125
}
@@ -172,13 +165,11 @@ public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory)
172165
protected static class PreloadableEmbeddedPostgresRegistrar implements BeanDefinitionRegistryPostProcessor, EnvironmentAware {
173166

174167
private final Set<AutoConfigureEmbeddedDatabase> databaseAnnotations;
175-
private final FlywayTest[] flywayAnnotations;
176168

177169
private Environment environment;
178170

179-
public PreloadableEmbeddedPostgresRegistrar(Set<AutoConfigureEmbeddedDatabase> databaseAnnotations, FlywayTest[] flywayAnnotations) {
171+
public PreloadableEmbeddedPostgresRegistrar(Set<AutoConfigureEmbeddedDatabase> databaseAnnotations) {
180172
this.databaseAnnotations = databaseAnnotations;
181-
this.flywayAnnotations = flywayAnnotations;
182173
}
183174

184175
@Override
@@ -211,7 +202,7 @@ public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) t
211202
DatabaseDescriptor databaseDescriptor = resolveDatabaseDescriptor(environment, databaseAnnotation);
212203

213204
BeanDefinitionHolder dataSourceInfo = getDataSourceBeanDefinition(beanFactory, databaseAnnotation);
214-
BeanDefinitionHolder flywayInfo = getFlywayBeanDefinition(beanFactory, flywayAnnotations);
205+
BeanDefinitionHolder flywayInfo = getFlywayBeanDefinition(beanFactory);
215206

216207
RootBeanDefinition dataSourceDefinition = new RootBeanDefinition();
217208
dataSourceDefinition.setPrimary(dataSourceInfo.getBeanDefinition().isPrimary());
@@ -220,7 +211,7 @@ public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) t
220211
dataSourceDefinition.setBeanClass(EmptyEmbeddedPostgresDataSourceFactoryBean.class);
221212
dataSourceDefinition.getConstructorArgumentValues().addIndexedArgumentValue(0, databaseDescriptor);
222213
} else {
223-
BeanDefinitionHolder contextInfo = getDataSourceContextBeanDefinition(beanFactory, flywayAnnotations);
214+
BeanDefinitionHolder contextInfo = getDataSourceContextBeanDefinition(beanFactory, flywayInfo.getBeanName());
224215

225216
if (contextInfo == null) {
226217
RootBeanDefinition dataSourceContextDefinition = new RootBeanDefinition();
@@ -309,20 +300,8 @@ protected static BeanDefinitionHolder getDataSourceBeanDefinition(ConfigurableLi
309300
throw new IllegalStateException("No primary DataSource found, embedded version will not be used");
310301
}
311302

312-
protected static BeanDefinitionHolder getDataSourceContextBeanDefinition(ConfigurableListableBeanFactory beanFactory, FlywayTest[] annotations) {
313-
String flywayBeanName = resolveFlywayBeanName(annotations);
314-
315-
if (StringUtils.isNotBlank(flywayBeanName)) {
316-
String contextBeanName = flywayBeanName + "DataSourceContext";
317-
if (beanFactory.containsBean(contextBeanName)) {
318-
BeanDefinition beanDefinition = beanFactory.getBeanDefinition(contextBeanName);
319-
return new BeanDefinitionHolder(beanDefinition, contextBeanName);
320-
} else {
321-
return null;
322-
}
323-
}
324-
325-
String[] beanNames = beanFactory.getBeanNamesForType(FlywayDataSourceContext.class);
303+
protected static BeanDefinitionHolder getDataSourceContextBeanDefinition(ConfigurableListableBeanFactory beanFactory, String flywayName) {
304+
String[] beanNames = beanFactory.getBeanNamesForType(FlywayDataSourceContext.class, true, false);
326305

327306
if (ObjectUtils.isEmpty(beanNames)) {
328307
return null;
@@ -334,76 +313,26 @@ protected static BeanDefinitionHolder getDataSourceContextBeanDefinition(Configu
334313
return new BeanDefinitionHolder(beanDefinition, beanName);
335314
}
336315

337-
for (String beanName : beanNames) {
338-
BeanDefinition beanDefinition = beanFactory.getBeanDefinition(beanName);
339-
if (beanDefinition.isPrimary()) {
340-
return new BeanDefinitionHolder(beanDefinition, beanName);
341-
}
316+
if (beanFactory.containsBean(flywayName + "DataSourceContext")) {
317+
BeanDefinition beanDefinition = beanFactory.getBeanDefinition(flywayName + "DataSourceContext");
318+
return new BeanDefinitionHolder(beanDefinition, flywayName + "DataSourceContext");
342319
}
343320

344321
return null;
345322
}
346323

347-
protected static BeanDefinitionHolder getFlywayBeanDefinition(ConfigurableListableBeanFactory beanFactory, FlywayTest[] annotations) {
348-
if (annotations.length > 1) {
349-
return null; // optimized loading is not supported yet when using multiple flyway test annotations
350-
}
351-
352-
String flywayBeanName = resolveFlywayBeanName(annotations);
353-
354-
if (StringUtils.isNotBlank(flywayBeanName)) {
355-
BeanDefinition beanDefinition = beanFactory.getBeanDefinition(flywayBeanName);
356-
return new BeanDefinitionHolder(beanDefinition, flywayBeanName);
357-
}
358-
359-
String[] beanNames = beanFactory.getBeanNamesForType(Flyway.class);
360-
361-
if (ObjectUtils.isEmpty(beanNames)) {
362-
return null;
363-
}
324+
protected static BeanDefinitionHolder getFlywayBeanDefinition(ConfigurableListableBeanFactory beanFactory) {
325+
String[] beanNames = beanFactory.getBeanNamesForType(Flyway.class, true, false);
364326

365327
if (beanNames.length == 1) {
366328
String beanName = beanNames[0];
367329
BeanDefinition beanDefinition = beanFactory.getBeanDefinition(beanName);
368330
return new BeanDefinitionHolder(beanDefinition, beanName);
369331
}
370332

371-
for (String beanName : beanNames) {
372-
BeanDefinition beanDefinition = beanFactory.getBeanDefinition(beanName);
373-
if (beanDefinition.isPrimary()) {
374-
return new BeanDefinitionHolder(beanDefinition, beanName);
375-
}
376-
}
377-
378333
return null;
379334
}
380335

381-
protected static FlywayTest[] findFlywayTestAnnotations(AnnotatedElement element) {
382-
if (repeatableAnnotationPresent) {
383-
org.flywaydb.test.annotation.FlywayTests flywayContainerAnnotation =
384-
AnnotatedElementUtils.findMergedAnnotation(element, org.flywaydb.test.annotation.FlywayTests.class);
385-
if (flywayContainerAnnotation != null) {
386-
return flywayContainerAnnotation.value();
387-
}
388-
}
389-
390-
FlywayTest flywayAnnotation = AnnotatedElementUtils.findMergedAnnotation(element, FlywayTest.class);
391-
if (flywayAnnotation != null) {
392-
return new FlywayTest[] { flywayAnnotation };
393-
}
394-
395-
return new FlywayTest[0];
396-
}
397-
398-
protected static String resolveFlywayBeanName(FlywayTest[] annotations) {
399-
FlywayTest annotation = annotations.length == 1 ? annotations[0] : null;
400-
if (annotation != null && flywayNameAttributePresent) {
401-
return annotation.flywayName();
402-
} else {
403-
return null;
404-
}
405-
}
406-
407336
protected static <T> Predicate<T> distinctByKey(Function<? super T, ?> keyExtractor) {
408337
Set<Object> seen = ConcurrentHashMap.newKeySet();
409338
return t -> seen.add(keyExtractor.apply(t));

embedded-database-spring-test/src/test/java/io/zonky/test/db/MultipleFlywayBeansClassLevelIntegrationTest.java

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
import org.springframework.beans.factory.annotation.Autowired;
2626
import org.springframework.context.annotation.Bean;
2727
import org.springframework.context.annotation.Configuration;
28+
import org.springframework.context.annotation.DependsOn;
29+
import org.springframework.context.annotation.Primary;
2830
import org.springframework.jdbc.core.JdbcTemplate;
2931
import org.springframework.test.context.ContextConfiguration;
3032
import org.springframework.test.context.junit4.SpringRunner;
@@ -39,16 +41,17 @@
3941
@RunWith(SpringRunner.class)
4042
@Category(MultiFlywayIntegrationTests.class)
4143
@FlywayTest(flywayName = "flyway1")
44+
@FlywayTest(flywayName = "flyway2")
4245
@FlywayTest(flywayName = "flyway3", invokeCleanDB = false)
4346
@AutoConfigureEmbeddedDatabase(beanName = "dataSource")
4447
@ContextConfiguration
4548
public class MultipleFlywayBeansClassLevelIntegrationTest {
4649

47-
private static final String SQL_SELECT_PERSONS = "select * from test.person";
48-
4950
@Configuration
5051
static class Config {
5152

53+
@Primary
54+
@DependsOn("flyway2")
5255
@Bean
5356
public Flyway flyway1(DataSource dataSource) {
5457
Flyway flyway = new Flyway();
@@ -62,8 +65,8 @@ public Flyway flyway1(DataSource dataSource) {
6265
public Flyway flyway2(DataSource dataSource) {
6366
Flyway flyway = new Flyway();
6467
flyway.setDataSource(dataSource);
65-
flyway.setSchemas("test");
66-
flyway.setLocations("db/test_migration/separated");
68+
flyway.setSchemas("next");
69+
flyway.setLocations("db/next_migration");
6770
return flyway;
6871
}
6972

@@ -93,12 +96,18 @@ public JdbcTemplate jdbcTemplate(DataSource dataSource) {
9396
public void databaseShouldBeLoadedByFlyway1AndAppendedByFlyway3() {
9497
assertThat(dataSource).isNotNull();
9598

96-
List<Map<String, Object>> persons = jdbcTemplate.queryForList(SQL_SELECT_PERSONS);
99+
List<Map<String, Object>> persons = jdbcTemplate.queryForList("select * from test.person");
97100
assertThat(persons).isNotNull().hasSize(3);
98101

99102
assertThat(persons).extracting("id", "first_name", "last_name").containsExactlyInAnyOrder(
100103
tuple(1L, "Dave", "Syer"),
101104
tuple(2L, "Tom", "Hanks"),
102105
tuple(3L, "Will", "Smith"));
106+
107+
List<Map<String, Object>> nextPersons = jdbcTemplate.queryForList("select * from next.person");
108+
assertThat(nextPersons).isNotNull().hasSize(1);
109+
110+
assertThat(nextPersons).extracting("id", "first_name", "surname").containsExactlyInAnyOrder(
111+
tuple(1L, "Dave", "Syer"));
103112
}
104113
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
package io.zonky.test.db;
2+
3+
import io.zonky.test.category.FlywayIntegrationTests;
4+
import org.flywaydb.core.Flyway;
5+
import org.junit.Test;
6+
import org.junit.experimental.categories.Category;
7+
import org.junit.runner.RunWith;
8+
import org.springframework.beans.factory.annotation.Autowired;
9+
import org.springframework.context.annotation.Bean;
10+
import org.springframework.context.annotation.Configuration;
11+
import org.springframework.context.annotation.DependsOn;
12+
import org.springframework.context.annotation.Primary;
13+
import org.springframework.jdbc.core.JdbcTemplate;
14+
import org.springframework.test.context.ContextConfiguration;
15+
import org.springframework.test.context.junit4.SpringRunner;
16+
17+
import javax.sql.DataSource;
18+
import java.util.List;
19+
import java.util.Map;
20+
21+
import static org.assertj.core.api.Assertions.assertThat;
22+
import static org.assertj.core.api.Assertions.tuple;
23+
24+
@RunWith(SpringRunner.class)
25+
@Category(FlywayIntegrationTests.class)
26+
@AutoConfigureEmbeddedDatabase(beanName = "dataSource")
27+
@ContextConfiguration
28+
public class MultipleFlywayBeansContextInitializationIntegrationTest {
29+
30+
@Configuration
31+
static class Config {
32+
33+
@Primary
34+
@DependsOn("flyway2")
35+
@Bean(initMethod = "migrate")
36+
public Flyway flyway1(DataSource dataSource) {
37+
Flyway flyway = new Flyway();
38+
flyway.setDataSource(dataSource);
39+
flyway.setSchemas("test");
40+
flyway.setLocations("db/migration", "db/test_migration/dependent");
41+
return flyway;
42+
}
43+
44+
@Bean(initMethod = "migrate")
45+
public Flyway flyway2(DataSource dataSource) {
46+
Flyway flyway = new Flyway();
47+
flyway.setDataSource(dataSource);
48+
flyway.setSchemas("next");
49+
flyway.setLocations("db/next_migration");
50+
return flyway;
51+
}
52+
53+
@Bean(initMethod = "migrate")
54+
public Flyway flyway3(DataSource dataSource) {
55+
Flyway flyway = new Flyway();
56+
flyway.setDataSource(dataSource);
57+
flyway.setSchemas("test");
58+
flyway.setLocations("db/test_migration/appendable");
59+
flyway.setValidateOnMigrate(false);
60+
return flyway;
61+
}
62+
63+
@Bean
64+
public JdbcTemplate jdbcTemplate(DataSource dataSource) {
65+
return new JdbcTemplate(dataSource);
66+
}
67+
}
68+
69+
@Autowired
70+
private DataSource dataSource;
71+
72+
@Autowired
73+
private JdbcTemplate jdbcTemplate;
74+
75+
@Test
76+
public void databaseShouldBeLoadedByFlyway1AndAppendedByFlyway3() {
77+
assertThat(dataSource).isNotNull();
78+
79+
List<Map<String, Object>> persons = jdbcTemplate.queryForList("select * from test.person");
80+
assertThat(persons).isNotNull().hasSize(3);
81+
82+
assertThat(persons).extracting("id", "first_name", "last_name").containsExactlyInAnyOrder(
83+
tuple(1L, "Dave", "Syer"),
84+
tuple(2L, "Tom", "Hanks"),
85+
tuple(3L, "Will", "Smith"));
86+
87+
List<Map<String, Object>> nextPersons = jdbcTemplate.queryForList("select * from next.person");
88+
assertThat(nextPersons).isNotNull().hasSize(1);
89+
90+
assertThat(nextPersons).extracting("id", "first_name", "surname").containsExactlyInAnyOrder(
91+
tuple(1L, "Dave", "Syer"));
92+
}
93+
}

0 commit comments

Comments
 (0)