diff --git a/cmd/pg-schema-diff/plan_cmd.go b/cmd/pg-schema-diff/plan_cmd.go index 1efe265..1b592ef 100644 --- a/cmd/pg-schema-diff/plan_cmd.go +++ b/cmd/pg-schema-diff/plan_cmd.go @@ -172,9 +172,15 @@ var ( convertToOutputString: planToJsonS, } + outputFormatSql = outputFormat{ + identifier: "sql", + convertToOutputString: planToSqlS, + } + outputFormats = []outputFormat{ outputFormatPretty, outputFormatJson, + outputFormatSql, } outputFormatStrings = func() []string { @@ -591,3 +597,19 @@ func planToJsonS(plan diff.Plan) string { } return string(jsonData) } + +func planToSqlS(plan diff.Plan) string { + sb := strings.Builder{} + + if len(plan.Statements) == 0 { + return "" + } + + var stmtStrs []string + for _, stmt := range plan.Statements { + stmtStrs = append(stmtStrs, statementToPrettyS(stmt)) + } + sb.WriteString(strings.Join(stmtStrs, "\n\n")) + + return sb.String() +} diff --git a/cmd/pg-schema-diff/plan_cmd_test.go b/cmd/pg-schema-diff/plan_cmd_test.go index c34f656..5daf605 100644 --- a/cmd/pg-schema-diff/plan_cmd_test.go +++ b/cmd/pg-schema-diff/plan_cmd_test.go @@ -1,9 +1,13 @@ package main import ( + "database/sql" "regexp" + "strings" "testing" "time" + + "github.com/stripe/pg-schema-diff/pkg/diff" ) func (suite *cmdTestSuite) TestPlanCmd() { @@ -91,6 +95,45 @@ func (suite *cmdTestSuite) TestPlanCmd() { args: []string{"--from-dir", "some-dir", "--to-dir", "some-other-dir"}, expectErrContains: []string{"at least one Postgres server"}, }, + { + name: "sql output format - from dsn to dsn", + args: []string{"--output-format", "sql"}, + dynamicArgs: []dArgGenerator{ + tempDsnDArg(suite.pgEngine, "from-dsn", nil), + tempDsnDArg(suite.pgEngine, "to-dsn", []string{"CREATE TABLE foobar()"}), + }, + outputContains: []string{"CREATE TABLE \"public\".\"foobar\"", ";"}, + }, + { + name: "sql output format - from dsn to dir", + args: []string{"--output-format", "sql"}, + dynamicArgs: []dArgGenerator{ + tempDsnDArg(suite.pgEngine, "from-dsn", []string{""}), + tempSchemaDirDArg("to-dir", []string{"CREATE TABLE foobar()"}), + }, + outputContains: []string{"CREATE TABLE \"public\".\"foobar\"", ";"}, + }, + { + name: "sql output format - multiple statements", + args: []string{"--output-format", "sql"}, + dynamicArgs: []dArgGenerator{ + tempDsnDArg(suite.pgEngine, "from-dsn", nil), + tempDsnDArg(suite.pgEngine, "to-dsn", []string{ + "CREATE TABLE table1()", + "CREATE TABLE table2()", + }), + }, + outputContains: []string{"CREATE TABLE \"public\".\"table1\"", "CREATE TABLE \"public\".\"table2\"", ";"}, + }, + { + name: "invalid output format", + args: []string{"--output-format", "invalid"}, + dynamicArgs: []dArgGenerator{ + tempDsnDArg(suite.pgEngine, "from-dsn", nil), + tempDsnDArg(suite.pgEngine, "to-dsn", []string{"CREATE TABLE foobar()"}), + }, + expectErrContains: []string{"invalid output format"}, + }, } { suite.Run(tc.name, func() { suite.runCmdWithAssertions(runCmdWithAssertionsParams{ @@ -211,3 +254,243 @@ func (suite *cmdTestSuite) TestParseInsertStatementStr() { }) } } + +func TestPlanToSqlS(t *testing.T) { + testCases := []struct { + name string + plan diff.Plan + expected string + }{ + { + name: "empty plan", + plan: diff.Plan{Statements: []diff.Statement{}}, + expected: "", + }, + { + name: "single statement", + plan: diff.Plan{ + Statements: []diff.Statement{ + {DDL: "CREATE TABLE test ()"}, + }, + }, + expected: "CREATE TABLE test ();\n\t-- Statement Timeout: 0s", + }, + { + name: "multiple statements", + plan: diff.Plan{ + Statements: []diff.Statement{ + {DDL: "CREATE TABLE test1 ()"}, + {DDL: "CREATE TABLE test2 ()"}, + }, + }, + expected: "CREATE TABLE test1 ();\n\t-- Statement Timeout: 0s\n\nCREATE TABLE test2 ();\n\t-- Statement Timeout: 0s", + }, + { + name: "statements with comments and timeouts should be included", + plan: diff.Plan{ + Statements: []diff.Statement{ + { + DDL: "CREATE INDEX CONCURRENTLY idx_test ON test (col)", + Timeout: time.Minute * 5, + Hazards: []diff.MigrationHazard{ + {Type: "SOME_HAZARD", Message: "This is dangerous"}, + }, + }, + }, + }, + expected: "CREATE INDEX CONCURRENTLY idx_test ON test (col);\n\t-- Statement Timeout: 5m0s\n\t-- Hazard SOME_HAZARD: This is dangerous", + }, + { + name: "statements already ending with semicolon", + plan: diff.Plan{ + Statements: []diff.Statement{ + {DDL: "CREATE TABLE test1 ();"}, + {DDL: "ALTER TABLE test SET DATA TYPE integer;"}, + }, + }, + expected: "CREATE TABLE test1 ();;\n\t-- Statement Timeout: 0s\n\nALTER TABLE test SET DATA TYPE integer;;\n\t-- Statement Timeout: 0s", + }, + { + name: "mixed statements with and without semicolons", + plan: diff.Plan{ + Statements: []diff.Statement{ + {DDL: "CREATE TABLE test1 ()"}, + {DDL: "ALTER TABLE test SET DATA TYPE integer;"}, + {DDL: "DROP TABLE test2"}, + }, + }, + expected: "CREATE TABLE test1 ();\n\t-- Statement Timeout: 0s\n\nALTER TABLE test SET DATA TYPE integer;;\n\t-- Statement Timeout: 0s\n\nDROP TABLE test2;\n\t-- Statement Timeout: 0s", + }, + { + name: "statements with trailing whitespace", + plan: diff.Plan{ + Statements: []diff.Statement{ + {DDL: "CREATE TABLE test () "}, + {DDL: "ALTER TABLE test ADD COLUMN id int; \n"}, + }, + }, + expected: "CREATE TABLE test () ;\n\t-- Statement Timeout: 0s\n\nALTER TABLE test ADD COLUMN id int; \n;\n\t-- Statement Timeout: 0s", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := planToSqlS(tc.plan) + if result != tc.expected { + t.Errorf("Expected:\n%s\nGot:\n%s", tc.expected, result) + } + }) + } +} + +func TestOutputFormatValidation(t *testing.T) { + testCases := []struct { + name string + formatStr string + expectError bool + expectedValue outputFormat + }{ + { + name: "valid pretty format", + formatStr: "pretty", + expectError: false, + expectedValue: outputFormatPretty, + }, + { + name: "valid json format", + formatStr: "json", + expectError: false, + expectedValue: outputFormatJson, + }, + { + name: "valid sql format", + formatStr: "sql", + expectError: false, + expectedValue: outputFormatSql, + }, + { + name: "invalid format", + formatStr: "invalid", + expectError: true, + }, + { + name: "empty format", + formatStr: "", + expectError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var format outputFormat + err := format.Set(tc.formatStr) + + if tc.expectError { + if err == nil { + t.Errorf("Expected error for format '%s', but got none", tc.formatStr) + } + } else { + if err != nil { + t.Errorf("Unexpected error for format '%s': %v", tc.formatStr, err) + } + if format.identifier != tc.expectedValue.identifier { + t.Errorf("Expected identifier '%s', got '%s'", tc.expectedValue.identifier, format.identifier) + } + } + }) + } +} + +func TestSqlFormatDoesNotContainHeaders(t *testing.T) { + // Test that SQL format doesn't contain formatting headers like "####" or "1." + plan := diff.Plan{ + Statements: []diff.Statement{ + {DDL: "CREATE TABLE test_table (id int)"}, + }, + } + + result := planToSqlS(plan) + + // Check that the result doesn't contain pretty format markers + forbiddenStrings := []string{"####", "Generated plan", "1.", "2.", "3."} + for _, forbidden := range forbiddenStrings { + if strings.Contains(result, forbidden) { + t.Errorf("SQL format output should not contain '%s', but found it in: %s", forbidden, result) + } + } + + // Check that it contains proper SQL + if !strings.Contains(result, "CREATE TABLE test_table") { + t.Errorf("SQL format should contain the actual SQL statement") + } + if !strings.Contains(result, ";") { + t.Errorf("SQL format should end statements with semicolon") + } +} + +func (suite *cmdTestSuite) TestSqlOutputExecutable() { + // End-to-end test to verify that generated SQL can actually be executed + // Create source and target databases + sourceDb := tempDbWithSchema(suite.T(), suite.pgEngine, []string{ + "CREATE TABLE users (id int PRIMARY KEY)", + }) + targetDb := tempDbWithSchema(suite.T(), suite.pgEngine, []string{ + "CREATE TABLE users (id int PRIMARY KEY)", + "CREATE TABLE posts (id int PRIMARY KEY, user_id int REFERENCES users(id))", + }) + + // Create a third database to test the generated SQL + testDb := tempDbWithSchema(suite.T(), suite.pgEngine, []string{ + "CREATE TABLE users (id int PRIMARY KEY)", + }) + + // Generate SQL using our new format + args := []string{ + "plan", + "--output-format", "sql", + "--from-dsn", sourceDb.GetDSN(), + "--to-dsn", targetDb.GetDSN(), + } + + rootCmd := buildRootCmd() + rootCmd.SetArgs(args) + var sqlOutput strings.Builder + rootCmd.SetOut(&sqlOutput) + rootCmd.SetErr(&strings.Builder{}) + + err := rootCmd.Execute() + suite.Require().NoError(err) + + generatedSQL := sqlOutput.String() + suite.T().Logf("Generated SQL: %s", generatedSQL) + + // Verify the SQL is not empty and contains expected content + suite.Assert().NotEmpty(generatedSQL) + suite.Assert().Contains(generatedSQL, "CREATE TABLE") + suite.Assert().Contains(generatedSQL, "posts") + + // Now try to execute the generated SQL against the test database + conn, err := sql.Open("pgx", testDb.GetDSN()) + suite.Require().NoError(err) + defer conn.Close() + + // Split SQL by semicolons and execute each statement + sqlStatements := strings.Split(strings.TrimSpace(generatedSQL), ";") + for _, stmt := range sqlStatements { + stmt = strings.TrimSpace(stmt) + if stmt == "" { + continue + } + // Add semicolon back for execution + stmt += ";" + suite.T().Logf("Executing SQL: %s", stmt) + _, err := conn.Exec(stmt) + suite.Require().NoError(err, "Failed to execute generated SQL statement: %s", stmt) + } + + // Verify that the posts table was created successfully + var tableName string + err = conn.QueryRow("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = 'posts'").Scan(&tableName) + suite.Require().NoError(err) + suite.Assert().Equal("posts", tableName) +}