Skip to content

Commit ce93f35

Browse files
Add restrict key support (#237)
--------- Co-authored-by: Velagent <[email protected]>
1 parent 72927ab commit ce93f35

File tree

6 files changed

+112
-5
lines changed

6 files changed

+112
-5
lines changed

cmd/pg-schema-diff/apply_cmd_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,9 @@ func (suite *cmdTestSuite) TestApplyCmd() {
7878
})
7979
// The migration should have been successful. Assert it was.
8080
expectedDb := tempDbWithSchema(suite.T(), suite.pgEngine, tc.expectedSchemaDDL)
81-
expectedDbDump, err := pgdump.GetDump(expectedDb, pgdump.WithSchemaOnly())
81+
expectedDbDump, err := pgdump.GetDump(expectedDb, pgdump.WithSchemaOnly(), pgdump.WithRestrictKey(pgdump.FixedRestrictKey))
8282
suite.Require().NoError(err)
83-
fromDbDump, err := pgdump.GetDump(fromDb, pgdump.WithSchemaOnly())
83+
fromDbDump, err := pgdump.GetDump(fromDb, pgdump.WithSchemaOnly(), pgdump.WithRestrictKey(pgdump.FixedRestrictKey))
8484
suite.Require().NoError(err)
8585

8686
suite.Equal(expectedDbDump, fromDbDump)

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ require (
66
github.com/go-logfmt/logfmt v0.6.0
77
github.com/google/go-cmp v0.5.9
88
github.com/google/uuid v1.3.0
9+
github.com/hashicorp/go-version v1.7.0
910
github.com/jackc/pgx/v4 v4.18.2
1011
github.com/kr/pretty v0.3.1
1112
github.com/lib/pq v1.10.2

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN
2929
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
3030
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
3131
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
32+
github.com/hashicorp/go-version v1.7.0 h1:5tqGy27NaOTB8yJKUZELlFAS/LTKJkrmONwQKeRZfjY=
33+
github.com/hashicorp/go-version v1.7.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA=
3234
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
3335
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
3436
github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo=

internal/migration_acceptance_tests/acceptance_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ func runTest(t *testing.T, tc acceptanceTestCase) {
183183

184184
// Make sure the pgdump after running the migration is the same as the
185185
// pgdump from a database where we directly run the newSchemaDDL
186-
oldDbDump, err := pgdump.GetDump(oldDb, pgdump.WithSchemaOnly())
186+
oldDbDump, err := pgdump.GetDump(oldDb, pgdump.WithSchemaOnly(), pgdump.WithRestrictKey(pgdump.FixedRestrictKey))
187187
require.NoError(t, err)
188188

189189
newDbDump := directlyRunDDLAndGetDump(t, engine, tc.expectedDBSchemaDDL)
@@ -221,7 +221,7 @@ func directlyRunDDLAndGetDump(t *testing.T, engine *pgengine.Engine, ddl []strin
221221
defer newDb.DropDB()
222222
require.NoError(t, applyDDL(newDb, ddl))
223223

224-
newDbDump, err := pgdump.GetDump(newDb, pgdump.WithSchemaOnly())
224+
newDbDump, err := pgdump.GetDump(newDb, pgdump.WithSchemaOnly(), pgdump.WithRestrictKey(pgdump.FixedRestrictKey))
225225
require.NoError(t, err)
226226
return newDbDump
227227
}

internal/pgdump/dump.go

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,32 @@ import (
44
"errors"
55
"fmt"
66
"os/exec"
7+
"regexp"
78

9+
"github.com/hashicorp/go-version"
810
"github.com/stripe/pg-schema-diff/internal/pgengine"
911
)
1012

13+
const (
14+
// FixedRestrictKey is a constant restricted key that can be used for tests and other use cases where
15+
// a constant restrict key is needed.
16+
FixedRestrictKey = "pgschemadiffrestrict"
17+
)
18+
19+
var (
20+
// versionRe matches the version returned by pg_dump.
21+
versionRe = regexp.MustCompile(`pg_dump \(PostgreSQL\) (\d+(?:\.\d+)?)`)
22+
23+
version15 = version.Must(version.NewSemver("15.0"))
24+
)
25+
1126
// Parameter represents a parameter to be pg_dump. Don't use a type alias for a string slice
1227
// because all parameters for pgdump should be explicitly added here
1328
type Parameter struct {
14-
values []string `explicit:"always"`
29+
values []string
30+
// minimumVersion is the minimum required version pg_dump must return for the parameter to be added. If
31+
// pg_dump is an older version, it will not be added. If nil, there is no restriction.
32+
minimumVersion *version.Version
1533
}
1634

1735
func WithExcludeSchema(pattern string) Parameter {
@@ -26,6 +44,17 @@ func WithSchemaOnly() Parameter {
2644
}
2745
}
2846

47+
// WithRestrictKey is used by PSQL to prevent injection of "meta" commands. If not explicitly provided,
48+
// a random one will be generated for each pg_dump run. This most likely needs to be fixed for any
49+
// usages of pg_dump in tests.
50+
func WithRestrictKey(restrictKey string) Parameter {
51+
return Parameter{
52+
values: []string{"--restrict-key", restrictKey},
53+
// Added in 17.6. https://www.postgresql.org/docs/release/17.6/.
54+
minimumVersion: version15,
55+
}
56+
}
57+
2958
// GetDump gets the pg_dump of the inputted database.
3059
// It is only intended to be used for testing. You cannot securely pass passwords with this implementation, so it will
3160
// only accept databases created for unit tests (spun up with the pgengine package)
@@ -39,13 +68,50 @@ func GetDump(db *pgengine.DB, additionalParams ...Parameter) (string, error) {
3968
}
4069

4170
func GetDumpUsingBinary(pgDumpBinaryPath string, db *pgengine.DB, additionalParams ...Parameter) (string, error) {
71+
version, err := getVersion(pgDumpBinaryPath)
72+
if err != nil {
73+
return "", fmt.Errorf("getVersion: %w", err)
74+
}
75+
4276
params := []string{
4377
db.GetDSN(),
4478
}
4579
for _, param := range additionalParams {
80+
if param.minimumVersion != nil && param.minimumVersion.GreaterThan(version) {
81+
// Exclude the parameter if the minimum version is not satisfied.
82+
continue
83+
}
4684
params = append(params, param.values...)
4785
}
86+
return runPgDumpCmd(pgDumpBinaryPath, params...)
87+
}
88+
89+
// ParseVersion parses a version string from pg_dump output and returns a Version object.
90+
// This function is exported to make it testable.
91+
func ParseVersion(versionString string) (*version.Version, error) {
92+
matches := versionRe.FindStringSubmatch(versionString)
93+
if len(matches) < 2 {
94+
return nil, fmt.Errorf("could not extract version from string: %s", versionString)
95+
}
96+
97+
// Parse the extracted version string
98+
v, err := version.NewVersion(matches[1])
99+
if err != nil {
100+
return nil, fmt.Errorf("could not parse version %s: %w", matches[1], err)
101+
}
102+
return v, nil
103+
}
104+
105+
func getVersion(pgDumpBinaryPath string) (*version.Version, error) {
106+
versionString, err := runPgDumpCmd(pgDumpBinaryPath, "--version")
107+
if err != nil {
108+
return nil, err
109+
}
110+
111+
return ParseVersion(versionString)
112+
}
48113

114+
func runPgDumpCmd(pgDumpBinaryPath string, params ...string) (string, error) {
49115
output, err := exec.Command(pgDumpBinaryPath, params...).CombinedOutput()
50116
if err != nil {
51117
return "", fmt.Errorf("running pg dump \noutput=%s\n: %w", output, err)

internal/pgdump/dump_test.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,41 @@ func TestGetDump(t *testing.T) {
5151
require.NotContains(t, onlyPublicSchemaDump, "test.bar")
5252
require.NotContains(t, onlyPublicSchemaDump, "some-id")
5353
}
54+
55+
func TestParseVersion(t *testing.T) {
56+
testCases := []struct {
57+
name string
58+
versionString string
59+
expectedVersion string
60+
expectError bool
61+
}{
62+
{
63+
name: "version 17.6",
64+
versionString: "pg_dump (PostgreSQL) 17.6",
65+
expectedVersion: "17.6.0",
66+
expectError: false,
67+
},
68+
{
69+
name: "version 17",
70+
versionString: "pg_dump (PostgreSQL) 17",
71+
expectedVersion: "17.0.0",
72+
expectError: false,
73+
},
74+
{
75+
name: "invalid version string",
76+
versionString: "invalid version",
77+
expectError: true,
78+
},
79+
}
80+
81+
for _, tc := range testCases {
82+
t.Run(tc.name, func(t *testing.T) {
83+
version, err := pgdump.ParseVersion(tc.versionString)
84+
if tc.expectError {
85+
require.Error(t, err)
86+
return
87+
}
88+
require.Equal(t, tc.expectedVersion, version.String())
89+
})
90+
}
91+
}

0 commit comments

Comments
 (0)