1
+ package com.example.dbtest
2
+
3
+ import org.junit.jupiter.api.extension.AfterEachCallback
4
+ import org.junit.jupiter.api.extension.BeforeEachCallback
5
+ import org.junit.jupiter.api.extension.ExtensionContext
6
+ import org.junit.jupiter.api.extension.ParameterContext
7
+ import org.junit.jupiter.api.extension.ParameterResolver
8
+ import java.nio.file.Files
9
+ import java.nio.file.Paths
10
+ import java.sql.Connection
11
+ import java.sql.DriverManager
12
+
13
+ const val schema = " dinosql_test"
14
+
15
+ class DbTestExtension (private val migrationsPath : String ) : BeforeEachCallback, AfterEachCallback, ParameterResolver {
16
+ private val schemaConn: Connection
17
+ private val url: String
18
+
19
+ init {
20
+ val user = System .getenv(" PG_USER" ) ? : " postgres"
21
+ val pass = System .getenv(" PG_PASSWORD" ) ? : " mysecretpassword"
22
+ val host = System .getenv(" PG_HOST" ) ? : " 127.0.0.1"
23
+ val port = System .getenv(" PG_PORT" ) ? : " 5432"
24
+ val db = System .getenv(" PG_DATABASE" ) ? : " dinotest"
25
+ url = " jdbc:postgresql://$host :$port /$db ?user=$user &password=$pass &sslmode=disable"
26
+
27
+ schemaConn = DriverManager .getConnection(url)
28
+ }
29
+
30
+ override fun beforeEach (context : ExtensionContext ) {
31
+ schemaConn.createStatement().execute(" CREATE SCHEMA $schema " )
32
+ val stmt = Files .readString(Paths .get(migrationsPath))
33
+ getConnection().createStatement().execute(stmt)
34
+ }
35
+
36
+ override fun afterEach (context : ExtensionContext ) {
37
+ schemaConn.createStatement().execute(" DROP SCHEMA $schema CASCADE" )
38
+ }
39
+
40
+ private fun getConnection (): Connection {
41
+ return DriverManager .getConnection(" $url ¤tSchema=$schema " )
42
+ }
43
+
44
+ override fun supportsParameter (parameterContext : ParameterContext , extensionContext : ExtensionContext ): Boolean {
45
+ return parameterContext.parameter.type == Connection ::class .java
46
+ }
47
+
48
+ override fun resolveParameter (parameterContext : ParameterContext , extensionContext : ExtensionContext ): Any {
49
+ return getConnection()
50
+ }
51
+ }
0 commit comments