Skip to content

Commit 9bdaf61

Browse files
committed
Be paranoid about concurrency
1 parent a4893b1 commit 9bdaf61

File tree

3 files changed

+121
-85
lines changed

3 files changed

+121
-85
lines changed

internal/sqltest/docker/enabled.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,12 @@ package docker
33
import (
44
"fmt"
55
"os/exec"
6+
7+
"golang.org/x/sync/singleflight"
68
)
79

10+
var flight singleflight.Group
11+
812
func Installed() error {
913
if _, err := exec.LookPath("docker"); err != nil {
1014
return fmt.Errorf("docker not found: %w", err)

internal/sqltest/docker/mysql.go

Lines changed: 60 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -5,84 +5,100 @@ import (
55
"database/sql"
66
"fmt"
77
"os/exec"
8-
"sync"
8+
"strings"
99
"time"
1010

1111
_ "github.com/go-sql-driver/mysql"
1212
)
1313

14-
var mysqlSync sync.Once
1514
var mysqlHost string
1615

1716
func StartMySQLServer(c context.Context) (string, error) {
1817
if err := Installed(); err != nil {
1918
return "", err
2019
}
20+
if mysqlHost != "" {
21+
return mysqlHost, nil
22+
}
23+
value, err, _ := flight.Do("mysql", func() (interface{}, error) {
24+
host, err := startMySQLServer(c)
25+
if err != nil {
26+
return "", err
27+
}
28+
mysqlHost = host
29+
return host, nil
30+
})
31+
if err != nil {
32+
return "", err
33+
}
34+
data, ok := value.(string)
35+
if !ok {
36+
return "", fmt.Errorf("returned value was not a string")
37+
}
38+
return data, nil
39+
}
2140

41+
func startMySQLServer(c context.Context) (string, error) {
2242
{
23-
_, err := exec.Command("docker", "pull", "mysql:8").CombinedOutput()
43+
_, err := exec.Command("docker", "pull", "mysql:9").CombinedOutput()
2444
if err != nil {
25-
return "", fmt.Errorf("docker pull: mysql:8 %w", err)
45+
return "", fmt.Errorf("docker pull: mysql:9 %w", err)
2646
}
2747
}
2848

29-
var syncErr error
30-
mysqlSync.Do(func() {
31-
ctx, cancel := context.WithTimeout(c, 10*time.Second)
32-
defer cancel()
49+
var exists bool
50+
{
51+
cmd := exec.Command("docker", "container", "inspect", "sqlc_sqltest_docker_mysql")
52+
// This means we've already started the container
53+
exists = cmd.Run() == nil
54+
}
3355

56+
if !exists {
3457
cmd := exec.Command("docker", "run",
3558
"--name", "sqlc_sqltest_docker_mysql",
3659
"-e", "MYSQL_ROOT_PASSWORD=mysecretpassword",
3760
"-e", "MYSQL_DATABASE=dinotest",
3861
"-p", "3306:3306",
3962
"-d",
40-
"mysql:8",
63+
"mysql:9",
4164
)
4265

4366
output, err := cmd.CombinedOutput()
4467
fmt.Println(string(output))
45-
if err != nil {
46-
syncErr = err
47-
return
48-
}
4968

50-
// Create a ticker that fires every 10ms
51-
ticker := time.NewTicker(10 * time.Millisecond)
52-
defer ticker.Stop()
69+
msg := `Conflict. The container name "/sqlc_sqltest_docker_mysql" is already in use by container`
70+
if !strings.Contains(string(output), msg) && err != nil {
71+
return "", err
72+
}
73+
}
5374

54-
uri := "root:mysecretpassword@/dinotest"
75+
ctx, cancel := context.WithTimeout(c, 10*time.Second)
76+
defer cancel()
5577

56-
db, err := sql.Open("mysql", uri)
57-
if err != nil {
58-
syncErr = fmt.Errorf("sql.Open: %w", err)
59-
return
60-
}
78+
// Create a ticker that fires every 10ms
79+
ticker := time.NewTicker(10 * time.Millisecond)
80+
defer ticker.Stop()
6181

62-
for {
63-
select {
64-
case <-ctx.Done():
65-
syncErr = fmt.Errorf("timeout reached: %w", ctx.Err())
66-
return
67-
68-
case <-ticker.C:
69-
// Run your function here
70-
if err := db.PingContext(ctx); err != nil {
71-
continue
72-
}
73-
mysqlHost = uri
74-
return
75-
}
76-
}
77-
})
82+
uri := "root:mysecretpassword@/dinotest?multiStatements=true&parseTime=true"
7883

79-
if syncErr != nil {
80-
return "", syncErr
84+
db, err := sql.Open("mysql", uri)
85+
if err != nil {
86+
return "", fmt.Errorf("sql.Open: %w", err)
8187
}
8288

83-
if mysqlHost == "" {
84-
return "", fmt.Errorf("mysql server setup failed")
85-
}
89+
defer db.Close()
90+
91+
for {
92+
select {
93+
case <-ctx.Done():
94+
return "", fmt.Errorf("timeout reached: %w", ctx.Err())
8695

87-
return mysqlHost, nil
96+
case <-ticker.C:
97+
// Run your function here
98+
if err := db.PingContext(ctx); err != nil {
99+
continue
100+
}
101+
return uri, nil
102+
}
103+
}
88104
}

internal/sqltest/docker/postgres.go

Lines changed: 57 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -5,32 +5,57 @@ import (
55
"fmt"
66
"log/slog"
77
"os/exec"
8-
"sync"
8+
"strings"
99
"time"
1010

1111
"github.com/jackc/pgx/v5"
1212
)
1313

14-
var postgresSync sync.Once
1514
var postgresHost string
1615

1716
func StartPostgreSQLServer(c context.Context) (string, error) {
1817
if err := Installed(); err != nil {
1918
return "", err
2019
}
20+
if postgresHost != "" {
21+
return postgresHost, nil
22+
}
23+
value, err, _ := flight.Do("postgresql", func() (interface{}, error) {
24+
host, err := startPostgreSQLServer(c)
25+
if err != nil {
26+
return "", err
27+
}
28+
postgresHost = host
29+
return host, err
30+
})
31+
if err != nil {
32+
return "", err
33+
}
34+
data, ok := value.(string)
35+
if !ok {
36+
return "", fmt.Errorf("returned value was not a string")
37+
}
38+
return data, nil
39+
}
2140

41+
func startPostgreSQLServer(c context.Context) (string, error) {
2242
{
2343
_, err := exec.Command("docker", "pull", "postgres:16").CombinedOutput()
2444
if err != nil {
2545
return "", fmt.Errorf("docker pull: postgres:16 %w", err)
2646
}
2747
}
2848

29-
var syncErr error
30-
postgresSync.Do(func() {
31-
ctx, cancel := context.WithTimeout(c, 5*time.Second)
32-
defer cancel()
49+
uri := "postgres://postgres:mysecretpassword@localhost:5432/postgres?sslmode=disable"
50+
51+
var exists bool
52+
{
53+
cmd := exec.Command("docker", "container", "inspect", "sqlc_sqltest_docker_postgres")
54+
// This means we've already started the container
55+
exists = cmd.Run() == nil
56+
}
3357

58+
if !exists {
3459
cmd := exec.Command("docker", "run",
3560
"--name", "sqlc_sqltest_docker_postgres",
3661
"-e", "POSTGRES_PASSWORD=mysecretpassword",
@@ -43,47 +68,38 @@ func StartPostgreSQLServer(c context.Context) (string, error) {
4368

4469
output, err := cmd.CombinedOutput()
4570
fmt.Println(string(output))
46-
if err != nil {
47-
syncErr = err
48-
return
71+
72+
msg := `Conflict. The container name "/sqlc_sqltest_docker_postgres" is already in use by container`
73+
if !strings.Contains(string(output), msg) && err != nil {
74+
return "", err
4975
}
76+
}
5077

51-
// Create a ticker that fires every 10ms
52-
ticker := time.NewTicker(10 * time.Millisecond)
53-
defer ticker.Stop()
78+
ctx, cancel := context.WithTimeout(c, 5*time.Second)
79+
defer cancel()
5480

55-
uri := "postgres://postgres:mysecretpassword@localhost:5432/postgres?sslmode=disable"
81+
// Create a ticker that fires every 10ms
82+
ticker := time.NewTicker(10 * time.Millisecond)
83+
defer ticker.Stop()
5684

57-
for {
58-
select {
59-
case <-ctx.Done():
60-
syncErr = fmt.Errorf("timeout reached: %w", ctx.Err())
61-
return
85+
for {
86+
select {
87+
case <-ctx.Done():
88+
return "", fmt.Errorf("timeout reached: %w", ctx.Err())
6289

63-
case <-ticker.C:
64-
// Run your function here
65-
conn, err := pgx.Connect(ctx, uri)
66-
if err != nil {
67-
slog.Debug("sqltest", "connect", err)
68-
continue
69-
}
70-
if err := conn.Ping(ctx); err != nil {
71-
slog.Error("sqltest", "ping", err)
72-
continue
73-
}
74-
postgresHost = uri
75-
return
90+
case <-ticker.C:
91+
// Run your function here
92+
conn, err := pgx.Connect(ctx, uri)
93+
if err != nil {
94+
slog.Debug("sqltest", "connect", err)
95+
continue
96+
}
97+
defer conn.Close(ctx)
98+
if err := conn.Ping(ctx); err != nil {
99+
slog.Error("sqltest", "ping", err)
100+
continue
76101
}
102+
return uri, nil
77103
}
78-
})
79-
80-
if syncErr != nil {
81-
return "", syncErr
82104
}
83-
84-
if postgresHost == "" {
85-
return "", fmt.Errorf("postgres server setup failed")
86-
}
87-
88-
return postgresHost, nil
89105
}

0 commit comments

Comments
 (0)