Skip to content

Commit 7e28f8a

Browse files
authored
Merge pull request #72 from zonkyio/reflection-npe
#70 Fix NPE when debug logging is enabled on ReflectionTestUtils
2 parents 74dea85 + 3b1ff3d commit 7e28f8a

File tree

6 files changed

+140
-47
lines changed

6 files changed

+140
-47
lines changed

embedded-database-spring-test/src/main/java/io/zonky/test/db/flyway/FlywayClassUtils.java

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,9 @@
88
import org.springframework.core.io.ClassPathResource;
99
import org.springframework.util.ClassUtils;
1010

11-
import java.lang.reflect.InvocationTargetException;
12-
13-
import static org.apache.commons.lang3.reflect.MethodUtils.invokeStaticMethod;
14-
import static org.springframework.test.util.ReflectionTestUtils.getField;
15-
import static org.springframework.test.util.ReflectionTestUtils.invokeMethod;
11+
import static io.zonky.test.db.util.ReflectionUtils.getField;
12+
import static io.zonky.test.db.util.ReflectionUtils.invokeMethod;
13+
import static io.zonky.test.db.util.ReflectionUtils.invokeStaticMethod;
1614

1715
public class FlywayClassUtils {
1816

@@ -41,7 +39,7 @@ public class FlywayClassUtils {
4139
LoggerFactory.getLogger(FlywayConfigSnapshot.class).error("Unexpected error occurred while resolving flyway version", e);
4240
version = "0";
4341
}
44-
flywayVersion = Integer.valueOf(version);
42+
flywayVersion = Integer.parseInt(version);
4543

4644
if (flywayVersion >= 50) {
4745
boolean isCommercial;
@@ -58,8 +56,6 @@ public class FlywayClassUtils {
5856
isCommercial = true;
5957
} catch (FlywayException e) {
6058
isCommercial = false;
61-
} catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) {
62-
throw new RuntimeException(e);
6359
}
6460
isFlywayPro = isCommercial;
6561
} else {

embedded-database-spring-test/src/main/java/io/zonky/test/db/flyway/FlywayConfigSnapshot.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import org.flywaydb.core.Flyway;
66
import org.flywaydb.core.api.MigrationVersion;
77
import org.flywaydb.core.api.resolver.MigrationResolver;
8-
import org.springframework.test.util.ReflectionTestUtils;
98

109
import javax.sql.DataSource;
1110
import java.util.Arrays;
@@ -14,7 +13,8 @@
1413
import java.util.Objects;
1514
import java.util.stream.Collectors;
1615

17-
import static org.springframework.test.util.ReflectionTestUtils.getField;
16+
import static io.zonky.test.db.util.ReflectionUtils.getField;
17+
import static io.zonky.test.db.util.ReflectionUtils.invokeMethod;
1818

1919
/**
2020
* Represents an <b>immutable</b> snapshot of Flyway's configuration.
@@ -221,7 +221,7 @@ public FlywayConfigSnapshot(Flyway flyway) {
221221
.collect(Collectors.toList()));
222222
} else {
223223
this.tablespace = null;
224-
this.javaMigrations = ImmutableList.of();;
224+
this.javaMigrations = ImmutableList.of();
225225
}
226226

227227
if (flywayVersion >= 60 && isFlywayPro) {
@@ -232,15 +232,15 @@ public FlywayConfigSnapshot(Flyway flyway) {
232232
}
233233

234234
private static <T> T getValue(Object target, String method) {
235-
return ReflectionTestUtils.invokeMethod(target, method);
235+
return invokeMethod(target, method);
236236
}
237237

238238
private static <E> E[] getArray(Object target, String method) {
239-
return ReflectionTestUtils.invokeMethod(target, method);
239+
return invokeMethod(target, method);
240240
}
241241

242242
private static <K, V> Map<K, V> getMap(Object target, String method) {
243-
return ReflectionTestUtils.invokeMethod(target, method);
243+
return invokeMethod(target, method);
244244
}
245245

246246
public ClassLoader getClassLoader() {

embedded-database-spring-test/src/main/java/io/zonky/test/db/flyway/OptimizedFlywayTestExecutionListener.java

Lines changed: 16 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -42,17 +42,16 @@
4242
import javax.sql.DataSource;
4343
import java.lang.annotation.Annotation;
4444
import java.lang.reflect.AnnotatedElement;
45-
import java.lang.reflect.Constructor;
46-
import java.lang.reflect.InvocationTargetException;
4745
import java.lang.reflect.Method;
4846
import java.sql.SQLException;
4947
import java.util.Arrays;
5048
import java.util.Collection;
5149
import java.util.Map;
5250

53-
import static org.apache.commons.lang3.reflect.MethodUtils.invokeStaticMethod;
54-
import static org.springframework.test.util.ReflectionTestUtils.getField;
55-
import static org.springframework.test.util.ReflectionTestUtils.invokeMethod;
51+
import static io.zonky.test.db.util.ReflectionUtils.getField;
52+
import static io.zonky.test.db.util.ReflectionUtils.invokeConstructor;
53+
import static io.zonky.test.db.util.ReflectionUtils.invokeMethod;
54+
import static io.zonky.test.db.util.ReflectionUtils.invokeStaticMethod;
5655

5756
/**
5857
* Optimized implementation of the {@link org.flywaydb.test.junit.FlywayTestExecutionListener}
@@ -200,7 +199,7 @@ protected static DataSource reloadDataSource(FlywayDataSourceContext dataSourceC
200199
/**
201200
* Checks if test migrations are appendable to core migrations.
202201
*/
203-
protected static boolean isAppendable(Flyway flyway, FlywayTest annotation) throws ClassNotFoundException, NoSuchMethodException, InstantiationException, IllegalAccessException, InvocationTargetException {
202+
protected static boolean isAppendable(Flyway flyway, FlywayTest annotation) throws ClassNotFoundException {
204203
if (annotation.overrideLocations()) {
205204
return false;
206205
}
@@ -218,7 +217,7 @@ protected static boolean isAppendable(Flyway flyway, FlywayTest annotation) thro
218217
return coreVersion.compareTo(testVersion) < 0;
219218
}
220219

221-
protected static MigrationVersion findFirstVersion(Flyway flyway, String... locations) throws ClassNotFoundException, NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException {
220+
protected static MigrationVersion findFirstVersion(Flyway flyway, String... locations) throws ClassNotFoundException {
222221
Collection<ResolvedMigration> migrations = resolveMigrations(flyway, locations);
223222
return migrations.stream()
224223
.filter(migration -> migration.getVersion() != null)
@@ -227,7 +226,7 @@ protected static MigrationVersion findFirstVersion(Flyway flyway, String... loca
227226
.orElse(MigrationVersion.EMPTY);
228227
}
229228

230-
protected static MigrationVersion findLastVersion(Flyway flyway, String... locations) throws ClassNotFoundException, NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException {
229+
protected static MigrationVersion findLastVersion(Flyway flyway, String... locations) throws ClassNotFoundException {
231230
Collection<ResolvedMigration> migrations = resolveMigrations(flyway, locations);
232231
return migrations.stream()
233232
.filter(migration -> migration.getVersion() != null)
@@ -236,7 +235,7 @@ protected static MigrationVersion findLastVersion(Flyway flyway, String... locat
236235
.orElse(MigrationVersion.EMPTY);
237236
}
238237

239-
protected static Collection<ResolvedMigration> resolveMigrations(Flyway flyway, String... locations) throws ClassNotFoundException, NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException {
238+
protected static Collection<ResolvedMigration> resolveMigrations(Flyway flyway, String... locations) throws ClassNotFoundException {
240239
MigrationResolver resolver = createMigrationResolver(flyway, locations);
241240

242241
if (flywayVersion >= 52) {
@@ -250,33 +249,26 @@ protected static Collection<ResolvedMigration> resolveMigrations(Flyway flyway,
250249
}
251250
}
252251

253-
protected static MigrationResolver createMigrationResolver(Flyway flyway, String... locations) throws ClassNotFoundException, NoSuchMethodException, IllegalAccessException, InvocationTargetException, InstantiationException {
252+
protected static MigrationResolver createMigrationResolver(Flyway flyway, String... locations) throws ClassNotFoundException {
254253
String[] oldLocations = getFlywayLocations(flyway);
255254
try {
256255
setFlywayLocations(flyway, locations);
257256

258257
if (flywayVersion >= 60) {
259258
Object configuration = getField(flyway, "configuration");
260-
261-
Class<?> jdbcConnectionFactoryType = ClassUtils.forName("org.flywaydb.core.internal.jdbc.JdbcConnectionFactory", classLoader);
262-
Object jdbcConnectionFactory = jdbcConnectionFactoryType.getConstructors()[0].newInstance(
263-
invokeMethod(configuration, "getDataSource"), 0);
264-
259+
Object jdbcConnectionFactory = invokeConstructor("org.flywaydb.core.internal.jdbc.JdbcConnectionFactory", invokeMethod(configuration, "getDataSource"), 0);
265260
Object sqlScriptFactory = invokeStaticMethod(DatabaseFactory.class, "createSqlScriptFactory", jdbcConnectionFactory, configuration);
266261
Object sqlScriptExecutorFactory = invokeStaticMethod(DatabaseFactory.class, "createSqlScriptExecutorFactory", jdbcConnectionFactory);
267-
268-
Class<?> scannerType = ClassUtils.forName("org.flywaydb.core.internal.scanner.Scanner", classLoader);
269-
Constructor<?> scannerConstructor = scannerType.getConstructors()[0];
270262
Object scanner;
271263

272-
if (scannerConstructor.getParameterCount() == 4) {
273-
scanner = scannerConstructor.newInstance(
264+
try {
265+
scanner = invokeConstructor("org.flywaydb.core.internal.scanner.Scanner",
274266
ClassUtils.forName("org.flywaydb.core.api.migration.JavaMigration", classLoader),
275267
Arrays.asList((Object[]) invokeMethod(configuration, "getLocations")),
276268
invokeMethod(configuration, "getClassLoader"),
277269
invokeMethod(configuration, "getEncoding"));
278-
} else {
279-
scanner = scannerConstructor.newInstance(
270+
} catch (RuntimeException ex) {
271+
scanner = invokeConstructor("org.flywaydb.core.internal.scanner.Scanner",
280272
Arrays.asList((Object[]) invokeMethod(configuration, "getLocations")),
281273
invokeMethod(configuration, "getClassLoader"),
282274
invokeMethod(configuration, "getEncoding"));
@@ -287,15 +279,14 @@ protected static MigrationResolver createMigrationResolver(Flyway flyway, String
287279
Object configuration = getField(flyway, "configuration");
288280
Object database = invokeStaticMethod(DatabaseFactory.class, "createDatabase", flyway, false);
289281
Object factory = invokeMethod(database, "createSqlStatementBuilderFactory");
290-
Class<?> scannerType = ClassUtils.forName("org.flywaydb.core.internal.scanner.Scanner", classLoader);
291-
Object scanner = scannerType.getConstructors()[0].newInstance(
282+
Object scanner = invokeConstructor("org.flywaydb.core.internal.scanner.Scanner",
292283
Arrays.asList((Object[]) invokeMethod(configuration, "getLocations")),
293284
invokeMethod(configuration, "getClassLoader"),
294285
invokeMethod(configuration, "getEncoding"));
295286
return invokeMethod(flyway, "createMigrationResolver", database, scanner, scanner, factory);
296287
} else if (flywayVersion >= 51) {
297288
Object configuration = getField(flyway, "configuration");
298-
Object scanner = Scanner.class.getConstructors()[0].newInstance(configuration);
289+
Object scanner = invokeConstructor(Scanner.class, configuration);
299290
Object placeholderReplacer = invokeMethod(flyway, "createPlaceholderReplacer");
300291
return invokeMethod(flyway, "createMigrationResolver", null, scanner, placeholderReplacer);
301292
} else if (flywayVersion >= 40) {
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
package io.zonky.test.db.util;
2+
3+
import org.springframework.test.util.AopTestUtils;
4+
import org.springframework.util.Assert;
5+
import org.springframework.util.ClassUtils;
6+
import org.springframework.util.MethodInvoker;
7+
8+
import java.lang.reflect.Constructor;
9+
import java.lang.reflect.Field;
10+
import java.util.stream.IntStream;
11+
12+
public class ReflectionUtils {
13+
14+
private ReflectionUtils() {
15+
}
16+
17+
@SuppressWarnings("unchecked")
18+
public static <T> T getField(Object targetObject, String name) {
19+
Assert.notNull(targetObject, "Target object must not be null");
20+
21+
targetObject = AopTestUtils.getUltimateTargetObject(targetObject);
22+
Class<?> targetClass = targetObject.getClass();
23+
24+
Field field = org.springframework.util.ReflectionUtils.findField(targetClass, name);
25+
if (field == null) {
26+
throw new IllegalArgumentException(String.format("Could not find field '%s' on %s", name, safeToString(targetObject)));
27+
}
28+
29+
org.springframework.util.ReflectionUtils.makeAccessible(field);
30+
return (T) org.springframework.util.ReflectionUtils.getField(field, targetObject);
31+
}
32+
33+
@SuppressWarnings("unchecked")
34+
public static <T> T invokeMethod(Object targetObject, String name, Object... args) {
35+
Assert.notNull(targetObject, "Target object must not be null");
36+
Assert.hasText(name, "Method name must not be empty");
37+
38+
try {
39+
MethodInvoker methodInvoker = new MethodInvoker();
40+
methodInvoker.setTargetObject(targetObject);
41+
methodInvoker.setTargetMethod(name);
42+
methodInvoker.setArguments(args);
43+
methodInvoker.prepare();
44+
return (T) methodInvoker.invoke();
45+
} catch (Exception ex) {
46+
org.springframework.util.ReflectionUtils.handleReflectionException(ex);
47+
throw new IllegalStateException("Should never get here");
48+
}
49+
}
50+
51+
@SuppressWarnings("unchecked")
52+
public static <T> T invokeStaticMethod(Class<?> targetClass, String name, Object... args) {
53+
Assert.notNull(targetClass, "Target class must not be null");
54+
Assert.hasText(name, "Method name must not be empty");
55+
56+
try {
57+
MethodInvoker methodInvoker = new MethodInvoker();
58+
methodInvoker.setTargetClass(targetClass);
59+
methodInvoker.setTargetMethod(name);
60+
methodInvoker.setArguments(args);
61+
methodInvoker.prepare();
62+
return (T) methodInvoker.invoke();
63+
} catch (Exception ex) {
64+
org.springframework.util.ReflectionUtils.handleReflectionException(ex);
65+
throw new IllegalStateException("Should never get here");
66+
}
67+
}
68+
69+
public static <T> T invokeConstructor(String className, Object... args) throws ClassNotFoundException {
70+
Assert.notNull(className, "Target class must not be null");
71+
72+
Class<?> targetClass = ClassUtils.forName(className, null);
73+
return invokeConstructor(targetClass, args);
74+
}
75+
76+
@SuppressWarnings("unchecked")
77+
public static <T> T invokeConstructor(Class<?> targetClass, Object... args) {
78+
Assert.notNull(targetClass, "Target class must not be null");
79+
80+
try {
81+
for (Constructor<?> constructor : targetClass.getDeclaredConstructors()) {
82+
if (constructor.getParameterCount() != args.length) {
83+
continue;
84+
}
85+
Class<?>[] parameterTypes = constructor.getParameterTypes();
86+
boolean parametersMatches = IntStream.range(0, args.length)
87+
.allMatch(i -> ClassUtils.isAssignableValue(parameterTypes[i], args[i]));
88+
if (parametersMatches) {
89+
org.springframework.util.ReflectionUtils.makeAccessible(constructor);
90+
return (T) constructor.newInstance(args);
91+
}
92+
}
93+
throw new IllegalArgumentException(String.format("Could not find constructor on %s", targetClass));
94+
} catch (Exception ex) {
95+
org.springframework.util.ReflectionUtils.handleReflectionException(ex);
96+
throw new IllegalStateException("Should never get here");
97+
}
98+
}
99+
100+
private static String safeToString(Object target) {
101+
try {
102+
return String.format("target object [%s]", target);
103+
} catch (Exception ex) {
104+
return String.format("target of type [%s] whose toString() method threw [%s]",
105+
(target != null ? target.getClass().getName() : "unknown"), ex);
106+
}
107+
}
108+
}

embedded-database-spring-test/src/test/java/io/zonky/test/db/provider/impl/YandexPostgresDatabaseProviderTest.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import io.zonky.test.db.provider.DatabasePreparer;
2121
import io.zonky.test.db.provider.DatabaseType;
2222
import io.zonky.test.db.provider.ProviderType;
23-
import org.junit.Ignore;
2423
import org.junit.Test;
2524
import org.postgresql.ds.PGSimpleDataSource;
2625
import org.springframework.jdbc.core.JdbcTemplate;

embedded-database-spring-test/src/test/java/io/zonky/test/util/FlywayTestUtils.java

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,29 @@
11
package io.zonky.test.util;
22

3+
import io.zonky.test.db.flyway.FlywayClassUtils;
34
import org.flywaydb.core.Flyway;
45
import org.springframework.util.CollectionUtils;
56

67
import javax.sql.DataSource;
78
import java.util.List;
89

9-
import io.zonky.test.db.flyway.FlywayClassUtils;
10-
10+
import static io.zonky.test.db.util.ReflectionUtils.invokeMethod;
11+
import static io.zonky.test.db.util.ReflectionUtils.invokeStaticMethod;
1112
import static java.util.Collections.emptyList;
12-
import static org.apache.commons.lang3.reflect.MethodUtils.invokeStaticMethod;
13-
import static org.springframework.test.util.ReflectionTestUtils.invokeMethod;
1413

1514
public class FlywayTestUtils {
1615

1716
private FlywayTestUtils() {}
1817

19-
public static Flyway createFlyway(DataSource dataSource, String schema) throws Exception {
18+
public static Flyway createFlyway(DataSource dataSource, String schema) {
2019
return createFlyway(dataSource, schema, emptyList());
2120
}
2221

23-
public static Flyway createFlyway(DataSource dataSource, String schema, List<String> locations) throws Exception {
22+
public static Flyway createFlyway(DataSource dataSource, String schema, List<String> locations) {
2423
return createFlyway(dataSource, schema, locations, true);
2524
}
2625

27-
public static Flyway createFlyway(DataSource dataSource, String schema, List<String> locations, boolean validateOnMigrate) throws Exception {
26+
public static Flyway createFlyway(DataSource dataSource, String schema, List<String> locations, boolean validateOnMigrate) {
2827
int flywayVersion = FlywayClassUtils.getFlywayVersion();
2928

3029
if (flywayVersion >= 60) {

0 commit comments

Comments
 (0)