diff --git a/README.md b/README.md index 92d6f38..b100a3e 100644 --- a/README.md +++ b/README.md @@ -110,7 +110,7 @@ echo "CREATE TABLE bar (id varchar(255), message TEXT NOT NULL);" > schema/bar.s Apply the schema to a fresh database. [The connection string spec can be found here](https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING). Setting the `PGPASSWORD` env var will override any password set in the connection string and is recommended. ```bash -pg-schema-diff apply --dsn "postgres://postgres:postgres@localhost:5432/postgres" --schema-dir schema +pg-schema-diff apply --from-dsn "postgres://postgres:postgres@localhost:5432/postgres" --to-dir schema ``` ## 2. Updating schema @@ -121,7 +121,7 @@ echo "CREATE INDEX message_idx ON bar(message)" >> schema/bar.sql Apply the schema. Any hazards in the generated plan must be approved ```bash -pg-schema-diff apply --dsn "postgres://postgres:postgres@localhost:5432/postgres" --schema-dir schema --allow-hazards INDEX_BUILD +pg-schema-diff apply --from-dsn "postgres://postgres:postgres@localhost:5432/postgres" --to-dir schema --allow-hazards INDEX_BUILD ``` # Using Library diff --git a/cmd/pg-schema-diff/apply_cmd.go b/cmd/pg-schema-diff/apply_cmd.go index aa3e039..590f123 100644 --- a/cmd/pg-schema-diff/apply_cmd.go +++ b/cmd/pg-schema-diff/apply_cmd.go @@ -20,8 +20,9 @@ func buildApplyCmd() *cobra.Command { Short: "Migrate your database to the match the inputted schema (apply the schema to the database)", } - connFlags := createConnFlags(cmd) - planFlags := createPlanFlags(cmd) + connFlags := createConnectionFlags(cmd, "from-", " The database to migrate") + toSchemaFlags := createSchemaSourceFlags(cmd, "to-") + planOptsFlags := createPlanOptionsFlags(cmd) allowedHazardsTypesStrs := cmd.Flags().StringSlice("allow-hazards", nil, "Specify the hazards that are allowed. Order does not matter, and duplicates are ignored. If the"+ " migration plan contains unwanted hazards (hazards not in this list), then the migration will fail to run"+ @@ -29,28 +30,41 @@ func buildApplyCmd() *cobra.Command { skipConfirmPrompt := cmd.Flags().Bool("skip-confirm-prompt", false, "Skips prompt asking for user to confirm before applying") cmd.RunE = func(cmd *cobra.Command, args []string) error { logger := log.SimpleLogger() - connConfig, err := parseConnConfig(*connFlags, logger) + + connConfig, err := parseConnectionFlags(connFlags) + if err != nil { + return err + } + fromSchema := dsnSchemaSource(connConfig) + + toSchema, err := parseSchemaSource(*toSchemaFlags) if err != nil { return err } - planConfig, err := parsePlanConfig(*planFlags) + planOptions, err := parsePlanOptions(*planOptsFlags) if err != nil { return err } cmd.SilenceUsage = true - plan, err := generatePlan(context.Background(), logger, connConfig, planConfig) + plan, err := generatePlan(cmd.Context(), generatePlanParameters{ + fromSchema: fromSchema, + toSchema: toSchema, + tempDbConnConfig: connConfig, + planOptions: planOptions, + logger: logger, + }) if err != nil { return err } else if len(plan.Statements) == 0 { - fmt.Println("Schema matches expected. No plan generated") + cmd.Println("Schema matches expected. No plan generated") return nil } - fmt.Println(header("Review plan")) - fmt.Print(planToPrettyS(plan), "\n\n") + cmd.Println(header("Review plan")) + cmd.Print(planToPrettyS(plan), "\n\n") if err := failIfHazardsNotAllowed(plan, *allowedHazardsTypesStrs); err != nil { return err @@ -67,10 +81,10 @@ func buildApplyCmd() *cobra.Command { } } - if err := runPlan(context.Background(), connConfig, plan); err != nil { + if err := runPlan(cmd.Context(), cmd, connConfig, plan); err != nil { return err } - fmt.Println("Schema applied successfully") + cmd.Println("Schema applied successfully") return nil } @@ -109,7 +123,7 @@ func failIfHazardsNotAllowed(plan diff.Plan, allowedHazardsTypesStrs []string) e return nil } -func runPlan(ctx context.Context, connConfig *pgx.ConnConfig, plan diff.Plan) error { +func runPlan(ctx context.Context, cmd *cobra.Command, connConfig *pgx.ConnConfig, plan diff.Plan) error { connPool, err := openDbWithPgxConfig(connConfig) if err != nil { return err @@ -129,8 +143,8 @@ func runPlan(ctx context.Context, connConfig *pgx.ConnConfig, plan diff.Plan) er // must be executed within its own transaction block. Postgres will error if you try to set a TRANSACTION-level // timeout for it. SESSION-level statement_timeouts are respected by `ADD INDEX CONCURRENTLY` for i, stmt := range plan.Statements { - fmt.Println(header(fmt.Sprintf("Executing statement %d", getDisplayableStmtIdx(i)))) - fmt.Printf("%s\n\n", statementToPrettyS(stmt)) + cmd.Println(header(fmt.Sprintf("Executing statement %d", getDisplayableStmtIdx(i)))) + cmd.Printf("%s\n\n", statementToPrettyS(stmt)) start := time.Now() if _, err := conn.ExecContext(ctx, fmt.Sprintf("SET SESSION statement_timeout = %d", stmt.Timeout.Milliseconds())); err != nil { return fmt.Errorf("setting statement timeout: %w", err) @@ -141,9 +155,9 @@ func runPlan(ctx context.Context, connConfig *pgx.ConnConfig, plan diff.Plan) er if _, err := conn.ExecContext(ctx, stmt.ToSQL()); err != nil { return fmt.Errorf("executing migration statement. the database maybe be in a dirty state: %s: %w", stmt, err) } - fmt.Printf("Finished executing statement. Duration: %s\n", time.Since(start)) + cmd.Printf("Finished executing statement. Duration: %s\n", time.Since(start)) } - fmt.Println(header("Complete")) + cmd.Println(header("Complete")) return nil } diff --git a/cmd/pg-schema-diff/apply_cmd_test.go b/cmd/pg-schema-diff/apply_cmd_test.go new file mode 100644 index 0000000..51d99fb --- /dev/null +++ b/cmd/pg-schema-diff/apply_cmd_test.go @@ -0,0 +1,89 @@ +package main + +import ( + "github.com/stripe/pg-schema-diff/internal/pgdump" + "github.com/stripe/pg-schema-diff/internal/pgengine" +) + +func (suite *cmdTestSuite) TestApplyCmd() { + // Non-comprehensive set of tests for the plan command. Not totally comprehensive to avoid needing to avoid + // hindering developer velocity when updating the command. + type testCase struct { + name string + // fromDbArg is an optional argument to override the default "--from-dsn" arg. + fromDbArg func(db *pgengine.DB) []string + args []string + // dynamicArgs is function that can be used to build args that are dynamic, i.e., + // saving schemas to a randomly generated temporary directory. + dynamicArgs []dArgGenerator + + outputContains []string + // expectedSchema is the schema that is expected to be in the database after the migration. + // If nil, the expected schema will be the fromDDL. + expectedSchemaDDL []string + // expectErrContains is a list of substrings that are expected to be contained in the error returned by + // cmd.RunE. This is DISTINCT from stdErr. + expectErrContains []string + } + for _, tc := range []testCase{ + { + name: "to dir", + dynamicArgs: []dArgGenerator{tempSchemaDirDArg("to-dir", []string{"CREATE TABLE foobar();"})}, + + expectedSchemaDDL: []string{"CREATE TABLE foobar();"}, + }, + { + name: "to dsn", + dynamicArgs: []dArgGenerator{tempDsnDArg(suite.pgEngine, "to-dsn", []string{"CREATE TABLE foobar();"})}, + + expectedSchemaDDL: []string{"CREATE TABLE foobar();"}, + }, + { + name: "from empty dsn", + fromDbArg: func(db *pgengine.DB) []string { + tempSetPqEnvVarsForDb(suite.T(), db) + return []string{"--from-empty-dsn"} + }, + dynamicArgs: []dArgGenerator{tempSchemaDirDArg("to-dir", []string{"CREATE TABLE foobar();"})}, + + expectedSchemaDDL: []string{"CREATE TABLE foobar();"}, + }, + { + name: "no to schema provided", + expectErrContains: []string{"must be set"}, + }, + { + name: "two to schemas provided", + args: []string{"--to-dir", "some-other-dir", "--to-dsn", "some-dsn"}, + expectErrContains: []string{"only one of"}, + }, + } { + suite.Run(tc.name, func() { + fromDb := tempDbWithSchema(suite.T(), suite.pgEngine, nil) + if tc.fromDbArg == nil { + tc.fromDbArg = func(db *pgengine.DB) []string { + return []string{"--from-dsn", db.GetDSN()} + } + } + args := append([]string{ + "apply", + "--skip-confirm-prompt", + }, tc.fromDbArg(fromDb)...) + args = append(args, tc.args...) + suite.runCmdWithAssertions(runCmdWithAssertionsParams{ + args: args, + dynamicArgs: tc.dynamicArgs, + outputContains: tc.outputContains, + expectErrContains: tc.expectErrContains, + }) + // The migration should have been successful. Assert it was. + expectedDb := tempDbWithSchema(suite.T(), suite.pgEngine, tc.expectedSchemaDDL) + expectedDbDump, err := pgdump.GetDump(expectedDb, pgdump.WithSchemaOnly()) + suite.Require().NoError(err) + fromDbDump, err := pgdump.GetDump(fromDb, pgdump.WithSchemaOnly()) + suite.Require().NoError(err) + + suite.Equal(expectedDbDump, fromDbDump) + }) + } +} diff --git a/cmd/pg-schema-diff/flags.go b/cmd/pg-schema-diff/flags.go index 8d6634b..22ef8cb 100644 --- a/cmd/pg-schema-diff/flags.go +++ b/cmd/pg-schema-diff/flags.go @@ -7,29 +7,51 @@ import ( "github.com/go-logfmt/logfmt" "github.com/jackc/pgx/v4" "github.com/spf13/cobra" - "github.com/stripe/pg-schema-diff/pkg/log" ) -type connFlags struct { - dsn string +type connectionFlags struct { + // dsn is the connection string for the database. + dsn string + dsnFlagName string + + // isEmptyDsnUsingPq indicates to connect via DSN using the pq environment variables and defaults. + isEmptyDsnUsingPq bool + isEmptyDsnUsingPqFlagName string } -func createConnFlags(cmd *cobra.Command) *connFlags { - flags := &connFlags{} +func createConnectionFlags(cmd *cobra.Command, prefix string, additionalHelp string) *connectionFlags { + var c connectionFlags + + c.dsnFlagName = prefix + "dsn" + dsnFlagHelp := "Connection string for the database (DB password can be specified through PGPASSWORD environment variable)." + if additionalHelp != "" { + dsnFlagHelp += " " + additionalHelp + } + cmd.Flags().StringVar(&c.dsn, c.dsnFlagName, "", dsnFlagHelp) + + c.isEmptyDsnUsingPqFlagName = prefix + "empty-dsn" + isEmptyDsnUsingPqFlagHelp := "Connect with an empty DSN using the pq environment variables and defaults." + if additionalHelp != "" { + isEmptyDsnUsingPqFlagHelp += " " + additionalHelp + } + cmd.Flags().BoolVar(&c.isEmptyDsnUsingPq, c.isEmptyDsnUsingPqFlagName, false, isEmptyDsnUsingPqFlagHelp) - cmd.Flags().StringVar(&flags.dsn, "dsn", "", "Connection string for the database (DB password can be specified through PGPASSWORD environment variable)") - // Don't mark dsn as a required flag. - // Allow users to use the "PGHOST" etc environment variables like `psql`. + return &c +} - return flags +func (c *connectionFlags) IsSet() bool { + return c.dsn != "" || c.isEmptyDsnUsingPq } -func parseConnConfig(c connFlags, logger log.Logger) (*pgx.ConnConfig, error) { - if c.dsn == "" { - logger.Warnf("DSN flag not set. Using libpq environment variables and default values.") +func parseConnectionFlags(flags *connectionFlags) (*pgx.ConnConfig, error) { + if !flags.isEmptyDsnUsingPq && flags.dsn == "" { + return nil, fmt.Errorf("must specify either --%s or --%s", flags.dsnFlagName, flags.isEmptyDsnUsingPqFlagName) } - - return pgx.ParseConfig(c.dsn) + connConfig, err := pgx.ParseConfig(flags.dsn) + if err != nil { + return nil, fmt.Errorf("could not parse connection string %q: %w", flags.dsn, err) + } + return connConfig, nil } // logFmtToMap parses all LogFmt key/value pairs from the provided string into a diff --git a/cmd/pg-schema-diff/main.go b/cmd/pg-schema-diff/main.go index 7fa0213..fac8a4e 100644 --- a/cmd/pg-schema-diff/main.go +++ b/cmd/pg-schema-diff/main.go @@ -6,21 +6,19 @@ import ( "github.com/spf13/cobra" ) -// rootCmd represents the base command when called without any subcommands -var rootCmd = &cobra.Command{ - Use: "pg-schema-diff", - Short: "Diff two Postgres schemas and generate the SQL to get from one to the other", -} - -func init() { +func buildRootCmd() *cobra.Command { + rootCmd := &cobra.Command{ + Use: "pg-schema-diff", + Short: "Diff two Postgres schemas and generate the SQL to get from one to the other", + } rootCmd.AddCommand(buildPlanCmd()) rootCmd.AddCommand(buildApplyCmd()) rootCmd.AddCommand(buildVersionCmd()) + return rootCmd } func main() { - err := rootCmd.Execute() - if err != nil { + if err := buildRootCmd().Execute(); err != nil { os.Exit(1) } } diff --git a/cmd/pg-schema-diff/main_test.go b/cmd/pg-schema-diff/main_test.go new file mode 100644 index 0000000..54d89cd --- /dev/null +++ b/cmd/pg-schema-diff/main_test.go @@ -0,0 +1,143 @@ +package main + +import ( + "bytes" + "database/sql" + "fmt" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/stripe/pg-schema-diff/internal/pgengine" +) + +type cmdTestSuite struct { + suite.Suite + pgEngine *pgengine.Engine +} + +func (suite *cmdTestSuite) SetupSuite() { + pgEngine, err := pgengine.StartEngine() + suite.Require().NoError(err) + suite.pgEngine = pgEngine +} + +func (suite *cmdTestSuite) TearDownSuite() { + suite.Require().NoError(suite.pgEngine.Close()) +} + +type runCmdWithAssertionsParams struct { + args []string + // dynamicArgs is function that can be used to build args that are dynamic, i.e., + // saving schemas to a randomly generated temporary directory. + dynamicArgs []dArgGenerator + + // outputContains is a list of substrings that are expected to be contained in the stdout output of the command. + outputContains []string + // expectErrContains is a list of substrings that are expected to be contained in the error returned by + // cmd.RunE. This is DISTINCT from stdErr. + expectErrContains []string +} + +func (suite *cmdTestSuite) runCmdWithAssertions(tc runCmdWithAssertionsParams) { + args := tc.args + for _, da := range tc.dynamicArgs { + args = append(args, da(suite.T())...) + } + + rootCmd := buildRootCmd() + rootCmd.SetArgs(args) + stdOut := &bytes.Buffer{} + rootCmd.SetOut(stdOut) + stdErr := &bytes.Buffer{} + rootCmd.SetErr(stdErr) + + err := rootCmd.Execute() + if len(tc.expectErrContains) > 0 { + for _, e := range tc.expectErrContains { + suite.ErrorContains(err, e) + } + } else { + stdErrStr := stdErr.String() + suite.Require().NoError(err) + // Only assert the std error is empty if we don't expect an error + suite.Empty(stdErrStr, "expected no stderr") + } + + stdOutStr := stdOut.String() + if len(tc.outputContains) > 0 { + for _, o := range tc.outputContains { + suite.Contains(stdOutStr, o) + } + } +} + +// dArgGenerator generates argument at the run-time of the test case... +// intended for resources that are not known at test start and potentially need +// to be cleaned up. +type dArgGenerator func(*testing.T) []string + +func tempSchemaDirDArg(argName string, ddl []string) dArgGenerator { + return func(t *testing.T) []string { + t.Helper() + return []string{"--" + argName, tempSchemaDir(t, ddl)} + } +} + +func tempSchemaDir(t *testing.T, ddl []string) string { + t.Helper() + dir := t.TempDir() + for i, d := range ddl { + require.NoError(t, os.WriteFile(filepath.Join(dir, fmt.Sprintf("ddl_%d.sql", i)), []byte(d), 0644)) + } + return dir +} + +func tempDsnDArg(pgEngine *pgengine.Engine, argName string, ddl []string) dArgGenerator { + return func(t *testing.T) []string { + t.Helper() + db := tempDbWithSchema(t, pgEngine, ddl) + return []string{"--" + argName, db.GetDSN()} + } +} + +func tempDbWithSchema(t *testing.T, pgEngine *pgengine.Engine, ddl []string) *pgengine.DB { + t.Helper() + db, err := pgEngine.CreateDatabase() + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, db.DropDB()) + }) + dbPool, err := sql.Open("pgx", db.GetDSN()) + require.NoError(t, err) + defer func() { + require.NoError(t, dbPool.Close()) + }() + for _, d := range ddl { + _, err := dbPool.Exec(d) + require.NoError(t, err) + } + return db +} + +func tempSetPqEnvVarsForDb(t *testing.T, db *pgengine.DB) { + t.Helper() + tempSetEnvVar(t, "PGHOST", db.GetConnOpts()[pgengine.ConnectionOptionHost]) + tempSetEnvVar(t, "PGPORT", db.GetConnOpts()[pgengine.ConnectionOptionPort]) + tempSetEnvVar(t, "PGUSER", db.GetConnOpts()[pgengine.ConnectionOptionUser]) + tempSetEnvVar(t, "PGDATABASE", db.GetName()) +} + +func tempSetEnvVar(t *testing.T, k, v string) { + t.Helper() + require.NoError(t, os.Setenv(k, v)) + t.Cleanup(func() { + require.NoError(t, os.Unsetenv(k)) + }) +} + +func TestCmdTestSuite(t *testing.T) { + suite.Run(t, new(cmdTestSuite)) +} diff --git a/cmd/pg-schema-diff/plan_cmd.go b/cmd/pg-schema-diff/plan_cmd.go index 36eabb2..1efe265 100644 --- a/cmd/pg-schema-diff/plan_cmd.go +++ b/cmd/pg-schema-diff/plan_cmd.go @@ -4,7 +4,6 @@ import ( "context" "database/sql" "encoding/json" - "errors" "fmt" "io" "regexp" @@ -14,6 +13,7 @@ import ( "github.com/jackc/pgx/v4" "github.com/spf13/cobra" + "github.com/stripe/pg-schema-diff/internal/util" "github.com/stripe/pg-schema-diff/pkg/diff" "github.com/stripe/pg-schema-diff/pkg/log" "github.com/stripe/pg-schema-diff/pkg/tempdb" @@ -38,28 +38,68 @@ func buildPlanCmd() *cobra.Command { Short: "Generate the diff between two databases and the SQL to get from one to the other", } - connFlags := createConnFlags(cmd) - planFlags := createPlanFlags(cmd) + fromSchemaFlags := createSchemaSourceFlags(cmd, "from-") + toSchemaFlags := createSchemaSourceFlags(cmd, "to-") + tempDbConnFlags := createConnectionFlags(cmd, "temp-db-", "The temporary database to use for schema extraction. This is optional if diffing to/from a Postgres instance") + planOptsFlags := createPlanOptionsFlags(cmd) + outputFmt := outputFormatPretty + cmd.Flags().Var( + &outputFmt, + "output-format", + fmt.Sprintf("Change the output format for what is printed. Defaults to pretty-printed human-readable output. (options: %s)", strings.Join(outputFormatStrings(), ", ")), + ) cmd.RunE = func(cmd *cobra.Command, args []string) error { logger := log.SimpleLogger() - connConfig, err := parseConnConfig(*connFlags, logger) + + fromSchema, err := parseSchemaSource(*fromSchemaFlags) + if err != nil { + return err + } + + toSchema, err := parseSchemaSource(*toSchemaFlags) + if err != nil { + return err + } + + if !tempDbConnFlags.IsSet() { + // A temporary database must be provided. Attempt to pull it from the from or to schema source. + if fromSchemaFlags.connFlags.IsSet() { + tempDbConnFlags = fromSchemaFlags.connFlags + } else if toSchemaFlags.connFlags.IsSet() { + tempDbConnFlags = toSchemaFlags.connFlags + } else { + // In the future, we may allow folks to plumb in a postgres binary that we start for them OR a separate + // flag that allows them to specify a temporary database DSN> + // + // Notably, a temporary database is NOT required if both databases are DSNs..., but inherently that means + // we can derive a tempdDbDsn (this case is never hit). + return fmt.Errorf("at least one Postgres server must be provided to generate a plan. either --%s, --%s or --%s must be set. Without a temporary Postgres database, pg-schema-diff cannot extract the schema from DDL", tempDbConnFlags.dsnFlagName, fromSchemaFlags.connFlags.dsnFlagName, toSchemaFlags.connFlags.dsnFlagName) + } + } + tempDbConnConfig, err := parseConnectionFlags(tempDbConnFlags) if err != nil { return err } - planConfig, err := parsePlanConfig(*planFlags) + planOpts, err := parsePlanOptions(*planOptsFlags) if err != nil { return err } cmd.SilenceUsage = true - plan, err := generatePlan(context.Background(), logger, connConfig, planConfig) + plan, err := generatePlan(cmd.Context(), generatePlanParameters{ + fromSchema: fromSchema, + toSchema: toSchema, + tempDbConnConfig: tempDbConnConfig, + planOptions: planOpts, + logger: logger, + }) if err != nil { return err } - fmt.Println(planFlags.outputFormat.convertToOutputString(plan)) + cmd.Println(outputFmt.convertToOutputString(plan)) return nil } @@ -67,25 +107,10 @@ func buildPlanCmd() *cobra.Command { } type ( - schemaFlags struct { + // parsePlanOptionsFlags stores the flags that are parsed into planOptions. + planOptionsFlags struct { includeSchemas []string excludeSchemas []string - } - - schemaSourceFlags struct { - schemaDirs []string - targetDatabaseDSN string - } - - outputFormat struct { - identifier string - convertToOutputString func(diff.Plan) string - } - - planFlags struct { - dbSchemaSourceFlags schemaSourceFlags - - schemaFlags schemaFlags dataPackNewTables bool disablePlanValidation bool @@ -93,7 +118,11 @@ type ( statementTimeoutModifiers []string lockTimeoutModifiers []string insertStatements []string - outputFormat outputFormat + } + + outputFormat struct { + identifier string + convertToOutputString func(diff.Plan) string } timeoutModifier struct { @@ -108,51 +137,81 @@ type ( lockTimeout time.Duration } - schemaSourceFactory func() (diff.SchemaSource, io.Closer, error) - - planConfig struct { - schemaSourceFactory schemaSourceFactory - opts []diff.PlanOpt - + // planOptions stores options that are plumbed into plan generation process and dictate post-plan processing. + planOptions struct { + opts []diff.PlanOpt statementTimeoutModifiers []timeoutModifier lockTimeoutModifiers []timeoutModifier insertStatements []insertStatement } + + // schemaSourceFactoryFlags stores the flags that are parsed into a schemaSourceFactory. + schemaSourceFactoryFlags struct { + // schemaDirs should be provided if the schema is defined via SQL files. + schemaDirs []string + schemaDirFlagName string + + // connFlags should be provided if the schema is defined through a database. + connFlags *connectionFlags + } + + // schemaSourceFactory provides a layer of indirection such that all database opening and closing can be done + // in a single place, i.e., in the plan generation function. It also enables schema source flag parsing to return + // errors while SilenceUsage=true, and database connection opening to have SilenceUsage=false. + schemaSourceFactory func() (diff.SchemaSource, io.Closer, error) ) var ( - outputFormatPretty outputFormat = outputFormat{identifier: "pretty", convertToOutputString: planToPrettyS} - outputFormatJson outputFormat = outputFormat{identifier: "json", convertToOutputString: planToJsonS} + outputFormatPretty = outputFormat{ + identifier: "pretty", + convertToOutputString: planToPrettyS, + } + + outputFormatJson = outputFormat{ + identifier: "json", + convertToOutputString: planToJsonS, + } + + outputFormats = []outputFormat{ + outputFormatPretty, + outputFormatJson, + } + + outputFormatStrings = func() []string { + var options []string + for _, format := range outputFormats { + options = append(options, format.identifier) + } + return options + + } ) func (e *outputFormat) String() string { - return string(e.identifier) + return e.identifier } func (e *outputFormat) Set(v string) error { - switch v { - case "pretty": - *e = outputFormatPretty - return nil - case "json": - *e = outputFormatJson - return nil - default: - return errors.New(`must be one of "pretty" or "json"`) + var options []string + for _, format := range outputFormats { + if format.identifier == v { + *e = format + return nil + } + options = append(options, format.identifier) } + return fmt.Errorf("invalid output format %q. Options are: %s", v, strings.Join(options, ", ")) } func (e *outputFormat) Type() string { return "outputFormat" } -func createPlanFlags(cmd *cobra.Command) *planFlags { - flags := &planFlags{} - flags.outputFormat = outputFormatPretty +func createPlanOptionsFlags(cmd *cobra.Command) *planOptionsFlags { + var flags planOptionsFlags - schemaSourceFlagsVar(cmd, &flags.dbSchemaSourceFlags) - - schemaFlagsVar(cmd, &flags.schemaFlags) + cmd.Flags().StringArrayVar(&flags.includeSchemas, "include-schema", nil, "Include the specified schema in the plan") + cmd.Flags().StringArrayVar(&flags.excludeSchemas, "exclude-schema", nil, "Exclude the specified schema in the plan") cmd.Flags().BoolVar(&flags.dataPackNewTables, "data-pack-new-tables", true, "If set, will data pack new tables in the plan to minimize table size (re-arranges columns).") cmd.Flags().BoolVar(&flags.disablePlanValidation, "disable-plan-validation", false, "If set, will disable plan validation. Plan validation runs the migration against a temporary"+ @@ -171,24 +230,21 @@ func createPlanFlags(cmd *cobra.Command) *planFlags { ), ) - cmd.Flags().Var(&flags.outputFormat, "output-format", "Change the output format for what is printed. Defaults to pretty-printed human-readable output. (options: pretty, json)") - - return flags + return &flags } -func schemaSourceFlagsVar(cmd *cobra.Command, p *schemaSourceFlags) { - cmd.Flags().StringArrayVar(&p.schemaDirs, "schema-dir", nil, "Directory of .SQL files to use as the schema source (can be multiple). Use to generate a diff between the target database and the schema in this directory.") - if err := cmd.MarkFlagDirname("schema-dir"); err != nil { +func createSchemaSourceFlags(cmd *cobra.Command, prefix string) *schemaSourceFactoryFlags { + var p schemaSourceFactoryFlags + + p.schemaDirFlagName = prefix + "dir" + cmd.Flags().StringArrayVar(&p.schemaDirs, p.schemaDirFlagName, nil, "Directory of .SQL files to use as the schema source (can be multiple).") + if err := cmd.MarkFlagDirname(p.schemaDirFlagName); err != nil { panic(err) } - cmd.Flags().StringVar(&p.targetDatabaseDSN, "schema-source-dsn", "", "DSN for the database to use as the schema source. Use to generate a diff between the target database and the schema in this database.") - cmd.MarkFlagsMutuallyExclusive("schema-dir", "schema-source-dsn") -} + p.connFlags = createConnectionFlags(cmd, prefix, " The database to use as the schema source") -func schemaFlagsVar(cmd *cobra.Command, p *schemaFlags) { - cmd.Flags().StringArrayVar(&p.includeSchemas, "include-schema", nil, "Include the specified schema in the plan") - cmd.Flags().StringArrayVar(&p.excludeSchemas, "exclude-schema", nil, "Exclude the specified schema in the plan") + return &p } func timeoutModifierFlagVar(cmd *cobra.Command, p *[]string, timeoutType string, shorthand string) { @@ -203,13 +259,56 @@ func timeoutModifierFlagVar(cmd *cobra.Command, p *[]string, timeoutType string, cmd.Flags().StringArrayVarP(p, flagName, shorthand, nil, description) } -func parsePlanConfig(p planFlags) (planConfig, error) { - schemaSourceFactory, err := parseSchemaSource(p.dbSchemaSourceFlags) - if err != nil { - return planConfig{}, err +func parseSchemaSource(p schemaSourceFactoryFlags) (schemaSourceFactory, error) { + // Store result in a var instead of returning early to ensure only one option is set. + var ssf schemaSourceFactory + + if len(p.schemaDirs) > 0 { + ssf = func() (diff.SchemaSource, io.Closer, error) { + schemaSource, err := diff.DirSchemaSource(p.schemaDirs) + if err != nil { + return nil, nil, err + } + return schemaSource, util.NoOpCloser(), nil + } + } + + if p.connFlags.IsSet() { + if ssf != nil { + return nil, fmt.Errorf("only one of --%s or --%s can be set", p.schemaDirFlagName, p.connFlags.dsnFlagName) + } + connConfig, err := parseConnectionFlags(p.connFlags) + if err != nil { + return nil, err + } + ssf = dsnSchemaSource(connConfig) + } + + if ssf == nil { + return nil, fmt.Errorf("either --%s or --%s must be set", p.schemaDirFlagName, p.connFlags.dsnFlagName) + } + return ssf, nil +} + +// dsnSchemaSource returns a schema source factory that connects to a database using the provided DSN. +// This exists in its own function to allow for the plan cmd to call it. +func dsnSchemaSource(connConfig *pgx.ConnConfig) schemaSourceFactory { + return func() (diff.SchemaSource, io.Closer, error) { + connPool, err := openDbWithPgxConfig(connConfig) + if err != nil { + return nil, nil, fmt.Errorf("opening db with pgx config: %w", err) + } + connPool.SetMaxOpenConns(defaultMaxConnections) + return diff.DBSchemaSource(connPool), connPool, nil + } +} + +func parsePlanOptions(p planOptionsFlags) (planOptions, error) { + opts := []diff.PlanOpt{ + diff.WithIncludeSchemas(p.includeSchemas...), + diff.WithExcludeSchemas(p.excludeSchemas...), } - opts := parseSchemaConfig(p.schemaFlags) if p.dataPackNewTables { opts = append(opts, diff.WithDataPackNewTables()) } @@ -221,7 +320,7 @@ func parsePlanConfig(p planFlags) (planConfig, error) { for _, s := range p.statementTimeoutModifiers { stm, err := parseTimeoutModifier(s) if err != nil { - return planConfig{}, fmt.Errorf("parsing statement timeout modifier from %q: %w", s, err) + return planOptions{}, fmt.Errorf("parsing statement timeout modifier from %q: %w", s, err) } statementTimeoutModifiers = append(statementTimeoutModifiers, stm) } @@ -230,7 +329,7 @@ func parsePlanConfig(p planFlags) (planConfig, error) { for _, s := range p.lockTimeoutModifiers { ltm, err := parseTimeoutModifier(s) if err != nil { - return planConfig{}, fmt.Errorf("parsing statement timeout modifier from %q: %w", s, err) + return planOptions{}, fmt.Errorf("parsing statement timeout modifier from %q: %w", s, err) } lockTimeoutModifiers = append(lockTimeoutModifiers, ltm) } @@ -239,13 +338,12 @@ func parsePlanConfig(p planFlags) (planConfig, error) { for _, i := range p.insertStatements { is, err := parseInsertStatementStr(i) if err != nil { - return planConfig{}, fmt.Errorf("parsing insert statement from %q: %w", i, err) + return planOptions{}, fmt.Errorf("parsing insert statement from %q: %w", i, err) } insertStatements = append(insertStatements, is) } - return planConfig{ - schemaSourceFactory: schemaSourceFactory, + return planOptions{ opts: opts, statementTimeoutModifiers: statementTimeoutModifiers, lockTimeoutModifiers: lockTimeoutModifiers, @@ -253,41 +351,6 @@ func parsePlanConfig(p planFlags) (planConfig, error) { }, nil } -func parseSchemaSource(p schemaSourceFlags) (schemaSourceFactory, error) { - if len(p.schemaDirs) > 0 { - return func() (diff.SchemaSource, io.Closer, error) { - schemaSource, err := diff.DirSchemaSource(p.schemaDirs) - if err != nil { - return nil, nil, err - } - return schemaSource, nil, nil - }, nil - } - - if p.targetDatabaseDSN != "" { - connConfig, err := pgx.ParseConfig(p.targetDatabaseDSN) - if err != nil { - return nil, fmt.Errorf("parsing DSN %q: %w", p.targetDatabaseDSN, err) - } - return func() (diff.SchemaSource, io.Closer, error) { - connPool, err := openDbWithPgxConfig(connConfig) - if err != nil { - return nil, nil, fmt.Errorf("opening db with pgx config: %w", err) - } - return diff.DBSchemaSource(connPool), connPool, nil - }, nil - } - - return nil, fmt.Errorf("either --schema-dir or --schema-source-dsn must be set") -} - -func parseSchemaConfig(p schemaFlags) []diff.PlanOpt { - return []diff.PlanOpt{ - diff.WithIncludeSchemas(p.includeSchemas...), - diff.WithExcludeSchemas(p.excludeSchemas...), - } -} - // parseTimeoutModifier attempts to parse an option representing a statement timeout modifier in the // form of regex=duration where duration could be a decimal number and ends with a unit func parseTimeoutModifier(val string) (timeoutModifier, error) { @@ -379,42 +442,48 @@ func parseInsertStatementStr(val string) (insertStatement, error) { }, nil } -func generatePlan(ctx context.Context, logger log.Logger, connConfig *pgx.ConnConfig, planConfig planConfig) (diff.Plan, error) { +type generatePlanParameters struct { + fromSchema schemaSourceFactory + toSchema schemaSourceFactory + tempDbConnConfig *pgx.ConnConfig + planOptions planOptions + logger log.Logger +} + +func generatePlan( + ctx context.Context, + params generatePlanParameters, +) (diff.Plan, error) { tempDbFactory, err := tempdb.NewOnInstanceFactory(ctx, func(ctx context.Context, dbName string) (*sql.DB, error) { - copiedConfig := connConfig.Copy() - copiedConfig.Database = dbName - return openDbWithPgxConfig(copiedConfig) - }, tempdb.WithRootDatabase(connConfig.Database)) + cfg := params.tempDbConnConfig.Copy() + cfg.Database = dbName + return openDbWithPgxConfig(cfg) + }, tempdb.WithRootDatabase(params.tempDbConnConfig.Database)) if err != nil { - return diff.Plan{}, err + return diff.Plan{}, fmt.Errorf("creating temp db factory: %w", err) } defer func() { err := tempDbFactory.Close() if err != nil { - logger.Errorf("error shutting down temp db factory: %v", err) + params.logger.Errorf("error shutting down temp db factory: %v", err) } }() - connPool, err := openDbWithPgxConfig(connConfig) + fromSchema, fromSchemaSourceCloser, err := params.fromSchema() if err != nil { - return diff.Plan{}, err + return diff.Plan{}, fmt.Errorf("creating schema source: %w", err) } - defer connPool.Close() - connPool.SetMaxOpenConns(defaultMaxConnections) + defer fromSchemaSourceCloser.Close() - schemaSource, schemaSourceCloser, err := planConfig.schemaSourceFactory() + toSchema, toSchemaSourceCloser, err := params.toSchema() if err != nil { return diff.Plan{}, fmt.Errorf("creating schema source: %w", err) } - if schemaSourceCloser != nil { - defer schemaSourceCloser.Close() - } - - connSource := diff.DBSchemaSource(connPool) + defer toSchemaSourceCloser.Close() - plan, err := diff.Generate(ctx, connSource, schemaSource, + plan, err := diff.Generate(ctx, fromSchema, toSchema, append( - planConfig.opts, + params.planOptions.opts, diff.WithTempDbFactory(tempDbFactory), )..., ) @@ -424,7 +493,7 @@ func generatePlan(ctx context.Context, logger log.Logger, connConfig *pgx.ConnCo modifiedPlan, err := applyPlanModifiers( plan, - planConfig, + params.planOptions, ) if err != nil { return diff.Plan{}, fmt.Errorf("applying plan modifiers: %w", err) @@ -435,7 +504,7 @@ func generatePlan(ctx context.Context, logger log.Logger, connConfig *pgx.ConnCo func applyPlanModifiers( plan diff.Plan, - config planConfig, + config planOptions, ) (diff.Plan, error) { for _, stm := range config.statementTimeoutModifiers { plan = plan.ApplyStatementTimeoutModifier(stm.regex, stm.timeout) diff --git a/cmd/pg-schema-diff/plan_cmd_test.go b/cmd/pg-schema-diff/plan_cmd_test.go index b323411..c34f656 100644 --- a/cmd/pg-schema-diff/plan_cmd_test.go +++ b/cmd/pg-schema-diff/plan_cmd_test.go @@ -4,15 +4,108 @@ import ( "regexp" "testing" "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) -func TestParseTimeoutModifierStr(t *testing.T) { - for _, tc := range []struct { - opt string `explicit:"always"` +func (suite *cmdTestSuite) TestPlanCmd() { + type testCase struct { + name string + args []string + dynamicArgs []dArgGenerator + // outputContains is a list of substrings that are expected to be contained in the stdout output of the command. + outputContains []string + // expectErrContains is a list of substrings that are expected to be contained in the error returned by + // cmd.RunE. This is DISTINCT from stdErr. + expectErrContains []string + } + // Non-comprehensive set of tests for the plan command. Not totally comprehensive to avoid needing to avoid + // hindering developer velocity when updating the command. + for _, tc := range []testCase{ + { + name: "from dsn to dsn", + dynamicArgs: []dArgGenerator{ + tempDsnDArg(suite.pgEngine, "from-dsn", nil), + tempDsnDArg(suite.pgEngine, "to-dsn", []string{"CREATE TABLE foobar()"}), + }, + outputContains: []string{"CREATE TABLE"}, + }, + { + name: "from dsn to dir", + dynamicArgs: []dArgGenerator{ + tempDsnDArg(suite.pgEngine, "from-dsn", []string{""}), + tempSchemaDirDArg("to-dir", []string{"CREATE TABLE foobar()"}), + }, + outputContains: []string{"CREATE TABLE"}, + }, + { + name: "from dir to dsn", + dynamicArgs: []dArgGenerator{ + tempSchemaDirDArg("from-dir", nil), + tempDsnDArg(suite.pgEngine, "to-dsn", []string{"CREATE TABLE foobar()"}), + }, + outputContains: []string{"CREATE TABLE"}, + }, + { + name: "from dir to dir", + dynamicArgs: []dArgGenerator{ + tempSchemaDirDArg("from-dir", nil), + tempSchemaDirDArg("to-dir", []string{"CREATE TABLE foobar()"}), + tempDsnDArg(suite.pgEngine, "temp-db-dsn", []string{""}), + }, + outputContains: []string{"CREATE TABLE"}, + }, + { + name: "from empty dsn to dir", + dynamicArgs: []dArgGenerator{ + func(t *testing.T) []string { + db := tempDbWithSchema(t, suite.pgEngine, []string{""}) + tempSetPqEnvVarsForDb(t, db) + return []string{"--from-empty-dsn"} + }, + tempSchemaDirDArg("to-dir", []string{"CREATE TABLE foobar()"}), + }, + outputContains: []string{"CREATE TABLE"}, + }, + { + name: "no from schema provided", + args: []string{"--to-dir", "some-other-dir"}, + expectErrContains: []string{"must be set"}, + }, + { + name: "no to schema provided", + args: []string{"--from-dir", "some-other-dir"}, + expectErrContains: []string{"must be set"}, + }, + { + name: "two from schemas provided", + args: []string{"--from-dir", "some-dir", "--from-dsn", "some-dsn", "--to-dir", "some-other-dir"}, + expectErrContains: []string{"only one of"}, + }, + { + name: "two to schemas provided", + args: []string{"--from-dir", "some-dir", "--to-dir", "some-other-dir", "--to-dsn", "some-dsn"}, + expectErrContains: []string{"only one of"}, + }, + { + name: "no postgres server provided", + args: []string{"--from-dir", "some-dir", "--to-dir", "some-other-dir"}, + expectErrContains: []string{"at least one Postgres server"}, + }, + } { + suite.Run(tc.name, func() { + suite.runCmdWithAssertions(runCmdWithAssertionsParams{ + args: append([]string{"plan"}, tc.args...), + dynamicArgs: tc.dynamicArgs, + outputContains: tc.outputContains, + expectErrContains: tc.expectErrContains, + }) + }) + } +} + +func (suite *cmdTestSuite) TestParseTimeoutModifierStr() { + for _, tc := range []struct { + opt string `explicit:"always"` expected timeoutModifier expectedErrContains string }{ @@ -51,19 +144,19 @@ func TestParseTimeoutModifierStr(t *testing.T) { expectedErrContains: "pattern regex could not be compiled", }, } { - t.Run(tc.opt, func(t *testing.T) { + suite.Run(tc.opt, func() { modifier, err := parseTimeoutModifier(tc.opt) if len(tc.expectedErrContains) > 0 { - assert.ErrorContains(t, err, tc.expectedErrContains) + suite.ErrorContains(err, tc.expectedErrContains) return } - require.NoError(t, err) - assert.Equal(t, tc.expected, modifier) + suite.Require().NoError(err) + suite.Equal(tc.expected, modifier) }) } } -func TestParseInsertStatementStr(t *testing.T) { +func (suite *cmdTestSuite) TestParseInsertStatementStr() { for _, tc := range []struct { opt string `explicit:"always"` expectedInsertStmt insertStatement @@ -107,14 +200,14 @@ func TestParseInsertStatementStr(t *testing.T) { expectedErrContains: "lock timeout duration could not be parsed", }, } { - t.Run(tc.opt, func(t *testing.T) { + suite.Run(tc.opt, func() { insertStatement, err := parseInsertStatementStr(tc.opt) if len(tc.expectedErrContains) > 0 { - assert.ErrorContains(t, err, tc.expectedErrContains) + suite.ErrorContains(err, tc.expectedErrContains) return } - require.NoError(t, err) - assert.Equal(t, tc.expectedInsertStmt, insertStatement) + suite.Require().NoError(err) + suite.Equal(tc.expectedInsertStmt, insertStatement) }) } } diff --git a/cmd/pg-schema-diff/version_cmd.go b/cmd/pg-schema-diff/version_cmd.go index d3b56b1..18e879e 100644 --- a/cmd/pg-schema-diff/version_cmd.go +++ b/cmd/pg-schema-diff/version_cmd.go @@ -17,7 +17,7 @@ func buildVersionCmd() *cobra.Command { if !ok { return fmt.Errorf("build information not available") } - fmt.Printf("version=%s\n", buildInfo.Main.Version) + cmd.Printf("version=%s\n", buildInfo.Main.Version) return nil } diff --git a/internal/pgengine/engine.go b/internal/pgengine/engine.go index a0cd553..01b958b 100644 --- a/internal/pgengine/engine.go +++ b/internal/pgengine/engine.go @@ -18,7 +18,10 @@ import ( type ConnectionOption string const ( + ConnectionOptionHost ConnectionOption = "host" + ConnectionOptionUser ConnectionOption = "user" ConnectionOptionDatabase ConnectionOption = "dbname" + ConnectionOptionPort ConnectionOption = "port" ) type ConnectionOptions map[ConnectionOption]string @@ -190,9 +193,9 @@ func (e *Engine) testIfInstanceServingTraffic() error { func (e *Engine) GetPostgresDatabaseConnOpts() ConnectionOptions { result := make(map[ConnectionOption]string) result[ConnectionOptionDatabase] = "postgres" - result["host"] = e.sockPath - result["user"] = e.superuser - result["port"] = strconv.Itoa(defaultPort) + result[ConnectionOptionHost] = e.sockPath + result[ConnectionOptionUser] = e.superuser + result[ConnectionOptionPort] = strconv.Itoa(defaultPort) result["sslmode"] = "disable" return result diff --git a/internal/util/closer.go b/internal/util/closer.go index 9bece01..5fcc0d2 100644 --- a/internal/util/closer.go +++ b/internal/util/closer.go @@ -1,5 +1,16 @@ package util +import "io" + +type noOpCloser struct{} + +func (noOpCloser) Close() error { return nil } + +// NoOpCloser returns a Closer that does nothing. +func NoOpCloser() io.Closer { + return noOpCloser{} +} + // DoOnErrOrPanic calls f if the value of err is not nil or if the goroutine is // panicking. If there is a panic, it is rethrown. // diff --git a/internal/util/closer_test.go b/internal/util/closer_test.go index 0b67799..1c64ea7 100644 --- a/internal/util/closer_test.go +++ b/internal/util/closer_test.go @@ -8,7 +8,11 @@ import ( "github.com/stripe/pg-schema-diff/internal/util" ) -func Test_DoOnErrOrPanicIsCalledOnError(t *testing.T) { +func TestNoOpCloser(t *testing.T) { + require.NoError(t, util.NoOpCloser().Close()) +} + +func TestDoOnErrOrPanicIsCalledOnError(t *testing.T) { var err error wasCalled := false defer func() { @@ -22,7 +26,7 @@ func Test_DoOnErrOrPanicIsCalledOnError(t *testing.T) { return } -func Test_DoOnErrOrPanicIsNotCalledOnNoError(t *testing.T) { +func TestDoOnErrOrPanicIsNotCalledOnNoError(t *testing.T) { var err error wasCalled := false defer func() { @@ -35,7 +39,7 @@ func Test_DoOnErrOrPanicIsNotCalledOnNoError(t *testing.T) { return } -func Test_DoOnErrOrPanicIsCalledOnPanic(t *testing.T) { +func TestDoOnErrOrPanicIsCalledOnPanic(t *testing.T) { var err error wasCalled := false defer func() {