|
34 | 34 | import io.zonky.test.db.provider.support.SimpleDatabaseTemplate; |
35 | 35 | import io.zonky.test.db.util.PropertyUtils; |
36 | 36 | import io.zonky.test.db.util.RandomStringUtils; |
| 37 | +import io.zonky.test.db.util.ReflectionUtils; |
37 | 38 | import org.postgresql.ds.PGSimpleDataSource; |
38 | 39 | import org.postgresql.ds.common.BaseDataSource; |
39 | 40 | import org.slf4j.Logger; |
40 | 41 | import org.slf4j.LoggerFactory; |
41 | 42 | import org.springframework.beans.factory.ObjectProvider; |
42 | 43 | import org.springframework.core.env.Environment; |
43 | 44 | import org.springframework.jdbc.core.JdbcTemplate; |
| 45 | +import org.springframework.util.ClassUtils; |
44 | 46 |
|
45 | 47 | import javax.sql.DataSource; |
46 | 48 | import java.io.IOException; |
|
56 | 58 | import java.util.concurrent.CompletableFuture; |
57 | 59 | import java.util.concurrent.ExecutionException; |
58 | 60 | import java.util.concurrent.Semaphore; |
| 61 | +import java.util.concurrent.atomic.AtomicBoolean; |
59 | 62 | import java.util.function.Consumer; |
60 | 63 |
|
| 64 | +import static io.zonky.test.db.util.ReflectionUtils.getField; |
61 | 65 | import static java.util.Collections.emptyList; |
62 | 66 |
|
63 | 67 | public class ZonkyPostgresDatabaseProvider implements TemplatableDatabaseProvider { |
@@ -133,6 +137,7 @@ private DatabaseInstance(DatabaseConfig config) throws IOException { |
133 | 137 | config.applyTo(builder); |
134 | 138 |
|
135 | 139 | postgres = builder.start(); |
| 140 | + registerShutdownHook(postgres); |
136 | 141 |
|
137 | 142 | DataSource dataSource = postgres.getDatabase("postgres", "postgres"); |
138 | 143 | JdbcTemplate jdbcTemplate = new JdbcTemplate(dataSource); |
@@ -190,6 +195,29 @@ private EmbeddedDatabase getDatabase(ClientConfig config, String dbName) { |
190 | 195 | PGSimpleDataSource dataSource = (PGSimpleDataSource) postgres.getDatabase("postgres", dbName, config.connectProperties); |
191 | 196 | return new BlockingDatabaseWrapper(new PostgresEmbeddedDatabase(dataSource, () -> dropDatabase(config, dbName)), semaphore); |
192 | 197 | } |
| 198 | + |
| 199 | + protected void registerShutdownHook(EmbeddedPostgres postgres) { |
| 200 | + try { |
| 201 | + AtomicBoolean closed = getField(postgres, "closed"); |
| 202 | + |
| 203 | + Runnable shutdownHandler = () -> { |
| 204 | + try { |
| 205 | + closed.set(false); |
| 206 | + postgres.close(); |
| 207 | + } catch (IOException e) { |
| 208 | + logger.error("Unexpected IOException when closing PostgreSQL server", e); |
| 209 | + } |
| 210 | + }; |
| 211 | + |
| 212 | + Class<?> applicationType = ClassUtils.forName("org.springframework.boot.SpringApplication", null); |
| 213 | + Object shutdownHandlers = ReflectionUtils.invokeStaticMethod(applicationType, "getShutdownHandlers"); |
| 214 | + ReflectionUtils.invokeMethod(shutdownHandlers, "add", shutdownHandler); |
| 215 | + |
| 216 | + closed.set(true); |
| 217 | + } catch (Throwable ex) { |
| 218 | + // ClassNotFoundException or NoClassDefFoundError... |
| 219 | + } |
| 220 | + } |
193 | 221 | } |
194 | 222 |
|
195 | 223 | private static class DatabaseConfig { |
|
0 commit comments