Skip to content

Commit d17db0e

Browse files
committed
kotlin: factor out db setup extension
1 parent f1f5a41 commit d17db0e

File tree

2 files changed

+56
-31
lines changed

2 files changed

+56
-31
lines changed

examples/kotlin/src/test/kotlin/com/example/authors/QueriesImplTest.kt

Lines changed: 5 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,15 @@
11
package com.example.authors
22

3-
import org.junit.jupiter.api.AfterEach
3+
import com.example.dbtest.DbTestExtension
44
import org.junit.jupiter.api.Assertions.assertEquals
5-
import org.junit.jupiter.api.BeforeEach
65
import org.junit.jupiter.api.Test
7-
import java.nio.file.Files
8-
import java.nio.file.Paths
6+
import org.junit.jupiter.api.extension.RegisterExtension
97
import java.sql.Connection
10-
import java.sql.DriverManager
118

12-
const val schema = "dinosql_test"
9+
class QueriesImplTest(private val conn: Connection) {
1310

14-
class QueriesImplTest {
15-
lateinit var schemaConn: Connection
16-
lateinit var conn: Connection
17-
18-
@BeforeEach
19-
fun setup() {
20-
val user = System.getenv("PG_USER") ?: "postgres"
21-
val pass = System.getenv("PG_PASSWORD") ?: "mysecretpassword"
22-
val host = System.getenv("PG_HOST") ?: "127.0.0.1"
23-
val port = System.getenv("PG_PORT") ?: "5432"
24-
val db = System.getenv("PG_DATABASE") ?: "dinotest"
25-
val url = "jdbc:postgresql://$host:$port/$db?user=$user&password=$pass&sslmode=disable"
26-
println("db: $url")
27-
28-
schemaConn = DriverManager.getConnection(url)
29-
schemaConn.createStatement().execute("CREATE SCHEMA $schema")
30-
31-
conn = DriverManager.getConnection("$url&currentSchema=$schema")
32-
val stmt = Files.readString(Paths.get("src/main/resources/schema.sql"))
33-
conn.createStatement().execute(stmt)
34-
}
35-
36-
@AfterEach
37-
fun teardown() {
38-
schemaConn.createStatement().execute("DROP SCHEMA $schema CASCADE")
11+
companion object {
12+
@JvmField @RegisterExtension val db = DbTestExtension("src/main/resources/schema.sql")
3913
}
4014

4115
@Test
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
package com.example.dbtest
2+
3+
import org.junit.jupiter.api.extension.AfterEachCallback
4+
import org.junit.jupiter.api.extension.BeforeEachCallback
5+
import org.junit.jupiter.api.extension.ExtensionContext
6+
import org.junit.jupiter.api.extension.ParameterContext
7+
import org.junit.jupiter.api.extension.ParameterResolver
8+
import java.nio.file.Files
9+
import java.nio.file.Paths
10+
import java.sql.Connection
11+
import java.sql.DriverManager
12+
13+
const val schema = "dinosql_test"
14+
15+
class DbTestExtension(private val migrationsPath: String) : BeforeEachCallback, AfterEachCallback, ParameterResolver {
16+
private val schemaConn: Connection
17+
private val url: String
18+
19+
init {
20+
val user = System.getenv("PG_USER") ?: "postgres"
21+
val pass = System.getenv("PG_PASSWORD") ?: "mysecretpassword"
22+
val host = System.getenv("PG_HOST") ?: "127.0.0.1"
23+
val port = System.getenv("PG_PORT") ?: "5432"
24+
val db = System.getenv("PG_DATABASE") ?: "dinotest"
25+
url = "jdbc:postgresql://$host:$port/$db?user=$user&password=$pass&sslmode=disable"
26+
27+
schemaConn = DriverManager.getConnection(url)
28+
}
29+
30+
override fun beforeEach(context: ExtensionContext) {
31+
schemaConn.createStatement().execute("CREATE SCHEMA $schema")
32+
val stmt = Files.readString(Paths.get(migrationsPath))
33+
getConnection().createStatement().execute(stmt)
34+
}
35+
36+
override fun afterEach(context: ExtensionContext) {
37+
schemaConn.createStatement().execute("DROP SCHEMA $schema CASCADE")
38+
}
39+
40+
private fun getConnection(): Connection {
41+
return DriverManager.getConnection("$url&currentSchema=$schema")
42+
}
43+
44+
override fun supportsParameter(parameterContext: ParameterContext, extensionContext: ExtensionContext): Boolean {
45+
return parameterContext.parameter.type == Connection::class.java
46+
}
47+
48+
override fun resolveParameter(parameterContext: ParameterContext, extensionContext: ExtensionContext): Any {
49+
return getConnection()
50+
}
51+
}

0 commit comments

Comments
 (0)