diff --git a/internal/migration_acceptance_tests/procedure_cases_test.go b/internal/migration_acceptance_tests/procedure_cases_test.go new file mode 100644 index 0000000..7a92a98 --- /dev/null +++ b/internal/migration_acceptance_tests/procedure_cases_test.go @@ -0,0 +1,164 @@ +package migration_acceptance_tests + +import "github.com/stripe/pg-schema-diff/pkg/diff" + +var procedureAcceptanceTestCases = []acceptanceTestCase{ + { + name: "No-op", + oldSchemaDDL: []string{ + ` + CREATE OR REPLACE PROCEDURE some_procedure(i integer) AS $$ + BEGIN + RAISE NOTICE 'foobar'; + END; + $$ LANGUAGE plpgsql; + `, + }, + newSchemaDDL: []string{ + ` + CREATE OR REPLACE PROCEDURE some_procedure(i integer) AS $$ + BEGIN + RAISE NOTICE 'foobar'; + END; + $$ LANGUAGE plpgsql; + `, + }, + + expectEmptyPlan: true, + }, + { + name: "Create procedure with no dependencies", + oldSchemaDDL: nil, + newSchemaDDL: []string{ + ` + CREATE OR REPLACE PROCEDURE some_procedure(val INTEGER) LANGUAGE plpgsql AS $$ + BEGIN + RAISE NOTICE 'Val, %', val; + END + $$; + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{diff.MigrationHazardTypeHasUntrackableDependencies}, + }, + { + name: "Create procedure with dependencies that also must be created", + oldSchemaDDL: nil, + newSchemaDDL: []string{ + ` + CREATE SEQUENCE user_id_seq; + + CREATE TABLE users ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL + ); + + CREATE OR REPLACE FUNCTION get_name(input_name TEXT) RETURNS TEXT AS $$ + SELECT input_name || '_some_fixed_val' + $$ LANGUAGE SQL; + + CREATE OR REPLACE PROCEDURE "Add User"(name TEXT) LANGUAGE SQL AS $$ + INSERT INTO users (id, name) VALUES (NEXTVAL('user_id_seq'), get_name(name)); + $$; + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{diff.MigrationHazardTypeHasUntrackableDependencies}, + }, + { + name: "Alter a procedure to have dependencies that must be created", + oldSchemaDDL: []string{ + ` + CREATE TABLE users (); + CREATE OR REPLACE PROCEDURE "Add User"(name TEXT) LANGUAGE SQL AS $$ + INSERT INTO users DEFAULT VALUES; + $$; + `, + }, + newSchemaDDL: []string{ + ` + CREATE SEQUENCE user_id_seq; + + CREATE TABLE users ( + id INTEGER, + name TEXT NOT NULL + ); + + CREATE OR REPLACE FUNCTION get_name(input_name TEXT) RETURNS TEXT AS $$ + SELECT input_name || '_some_fixed_val' + $$ LANGUAGE SQL; + + CREATE OR REPLACE PROCEDURE "Add User"(name TEXT) LANGUAGE SQL AS $$ + INSERT INTO users (id, name) VALUES (NEXTVAL('user_id_seq'), get_name(name)); + $$; + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{diff.MigrationHazardTypeHasUntrackableDependencies}, + }, + { + name: "Drop procedure and its dependencies", + oldSchemaDDL: []string{ + ` + CREATE SEQUENCE user_id_seq; + + CREATE TABLE users ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL + ); + + CREATE OR REPLACE FUNCTION get_name(input_name TEXT) RETURNS TEXT AS $$ + SELECT input_name || '_some_fixed_val' + $$ LANGUAGE SQL; + + CREATE OR REPLACE PROCEDURE "Add User"(name TEXT) LANGUAGE SQL AS $$ + INSERT INTO users (id, name) VALUES (NEXTVAL('user_id_seq'), get_name(name)); + $$; + `, + }, + newSchemaDDL: []string{ + ` + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeDeletesData, + diff.MigrationHazardTypeHasUntrackableDependencies, + }, + }, + { + // This reveals Postgres does not actually track dependencies of procedures outside of creation time. 
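+		// Concretely, once the procedure exists, the table it references can be dropped without error, which this
+		// test case exercises by keeping the procedure while dropping the table it inserts into.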
+ name: "Drop a procedure's dependencies but not the procedure", + oldSchemaDDL: []string{ + ` + CREATE TABLE users (); + + CREATE OR REPLACE PROCEDURE "Add User"(name TEXT) LANGUAGE SQL AS $$ + INSERT INTO users DEFAULT VALUES; + $$; + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE users(); + + CREATE OR REPLACE PROCEDURE "Add User"(name TEXT) LANGUAGE SQL AS $$ + INSERT INTO users DEFAULT VALUES; + $$; + + -- Drop the table the procedure depends on. This allows us to actually create a database with this schema. + DROP TABLE users; + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeDeletesData, + }, + planOpts: []diff.PlanOpt{ + // Skip plan validation because the acceptance test attempts to regenerate the plan after migrating and + // assert it's empty. As part of this plan regeneration, plan validation attempts to create a database with + // just an "Add User" procedure through normal SQL generation, which inherently fails because the users + // table does not exist. + diff.WithDoNotValidatePlan(), + }, + }, +} + +func (suite *acceptanceTestSuite) TestProcedureTestCases() { + suite.runTestCases(procedureAcceptanceTestCases) +} diff --git a/internal/migration_acceptance_tests/schema_cases_test.go b/internal/migration_acceptance_tests/schema_cases_test.go index 5bb0eb8..23c3f50 100644 --- a/internal/migration_acceptance_tests/schema_cases_test.go +++ b/internal/migration_acceptance_tests/schema_cases_test.go @@ -81,6 +81,12 @@ var schemaAcceptanceTests = []acceptanceTestCase{ CREATE INDEX bar_normal_idx ON bar(bar); CREATE INDEX bar_another_normal_id ON bar(bar, fizz); CREATE UNIQUE INDEX bar_unique_idx on bar(foo, buzz); + + CREATE OR REPLACE PROCEDURE some_procedure(i integer) AS $$ + BEGIN + RAISE NOTICE 'foobar'; + END; + $$ LANGUAGE plpgsql; `, }, newSchemaDDL: []string{ @@ -156,12 +162,18 @@ var schemaAcceptanceTests = []acceptanceTestCase{ CREATE INDEX bar_normal_idx ON bar(bar); CREATE INDEX bar_another_normal_id ON bar(bar, fizz); CREATE UNIQUE INDEX bar_unique_idx on bar(foo, buzz); + + CREATE OR REPLACE PROCEDURE some_procedure(i integer) AS $$ + BEGIN + RAISE NOTICE 'foobar'; + END; + $$ LANGUAGE plpgsql; `, }, expectEmptyPlan: true, }, { - name: "Add schema, drop schema, Add enum, Drop enum, Drop table, Add Table, Drop Seq, Add Seq, Drop Funcs, Add Funcs, Drop Triggers, Add Triggers, Create Extension, Drop Extension, Create Index Using Extension, Add policies, Drop policies", + name: "Add/drop all objects", roles: []string{"role_1"}, oldSchemaDDL: []string{ ` @@ -219,6 +231,10 @@ var schemaAcceptanceTests = []acceptanceTestCase{ CREATE INDEX foobar_normal_idx ON foobar USING hash (fizz); CREATE UNIQUE INDEX foobar_unique_idx ON foobar(foo, fizz DESC); + CREATE OR REPLACE PROCEDURE add_foobar(name TEXT) LANGUAGE SQL AS $$ + INSERT INTO foobar DEFAULT VALUES + $$; + CREATE POLICY foobar_foo_policy ON foobar FOR SELECT TO PUBLIC USING (foo = current_user); CREATE TRIGGER "some trigger" @@ -303,6 +319,10 @@ var schemaAcceptanceTests = []acceptanceTestCase{ CREATE POLICY "New_table_foo_policy" ON "New_table" FOR DELETE TO PUBLIC USING (version > 0); + CREATE OR REPLACE PROCEDURE "new new table"(name TEXT) LANGUAGE SQL AS $$ + INSERT INTO "New_table" (id, version) VALUES (NEXTVAL('schema_3.new_foobar_sequence'), schema_3."new add"(LENGTH(name), 1)) + $$; + CREATE TRIGGER "some trigger" BEFORE UPDATE ON "New_table" FOR EACH ROW diff --git a/internal/queries/queries.sql b/internal/queries/queries.sql index 8a1ea69..8f9c758 100644 --- 
a/internal/queries/queries.sql +++ b/internal/queries/queries.sql @@ -17,7 +17,7 @@ WHERE -- name: GetTables :many SELECT - c.oid AS oid, + c.oid, c.relname::TEXT AS table_name, table_namespace.nspname::TEXT AS table_schema_name, c.relreplident::TEXT AS replica_identity, @@ -40,7 +40,7 @@ INNER JOIN ON c.relnamespace = table_namespace.oid LEFT JOIN pg_catalog.pg_inherits AS table_inherits - ON table_inherits.inhrelid = c.oid + ON c.oid = table_inherits.inhrelid LEFT JOIN pg_catalog.pg_class AS parent_c ON table_inherits.inhparent = parent_c.oid @@ -98,16 +98,16 @@ SELECT FROM pg_catalog.pg_attribute AS a LEFT JOIN pg_catalog.pg_attrdef AS d - ON (d.adrelid = a.attrelid AND d.adnum = a.attnum) -LEFT JOIN pg_catalog.pg_collation AS coll ON coll.oid = a.attcollation + ON (a.attrelid = d.adrelid AND a.attnum = d.adnum) +LEFT JOIN pg_catalog.pg_collation AS coll ON a.attcollation = coll.oid LEFT JOIN pg_catalog.pg_namespace AS collation_namespace - ON collation_namespace.oid = coll.collnamespace + ON coll.collnamespace = collation_namespace.oid LEFT JOIN identity_col_seq ON - identity_col_seq.owner_relid = a.attrelid - AND identity_col_seq.owner_attnum = a.attnum + a.attrelid = identity_col_seq.owner_relid + AND a.attnum = identity_col_seq.owner_attnum WHERE a.attrelid = $1 AND a.attnum > 0 @@ -116,7 +116,7 @@ ORDER BY a.attnum; -- name: GetIndexes :many SELECT - c.oid AS oid, + c.oid, c.relname::TEXT AS index_name, table_c.relname::TEXT AS table_name, table_namespace.nspname::TEXT AS table_schema_name, @@ -132,21 +132,25 @@ SELECT COALESCE(parent_c.relname, '')::TEXT AS parent_index_name, COALESCE(parent_namespace.nspname, '')::TEXT AS parent_index_schema_name, ( - SELECT ARRAY_AGG(att.attname ORDER BY indkey_ord.ord) + SELECT + ARRAY_AGG( + att.attname + ORDER BY indkey_ord.ord + ) FROM UNNEST(i.indkey) WITH ORDINALITY AS indkey_ord (attnum, ord) INNER JOIN pg_catalog.pg_attribute AS att - ON att.attrelid = table_c.oid AND att.attnum = indkey_ord.attnum + ON att.attrelid = table_c.oid AND indkey_ord.attnum = att.attnum )::TEXT [] AS column_names, COALESCE(con.conislocal, false) AS constraint_is_local FROM pg_catalog.pg_class AS c -INNER JOIN pg_catalog.pg_index AS i ON (i.indexrelid = c.oid) -INNER JOIN pg_catalog.pg_class AS table_c ON (table_c.oid = i.indrelid) +INNER JOIN pg_catalog.pg_index AS i ON (c.oid = i.indexrelid) +INNER JOIN pg_catalog.pg_class AS table_c ON (i.indrelid = table_c.oid) INNER JOIN pg_catalog.pg_namespace AS table_namespace ON table_c.relnamespace = table_namespace.oid LEFT JOIN pg_catalog.pg_constraint AS con - ON (con.conindid = c.oid AND con.contype IN ('p', 'u', null)) + ON (c.oid = con.conindid AND con.contype IN ('p', 'u', null)) LEFT JOIN pg_catalog.pg_inherits AS idx_inherits ON (c.oid = idx_inherits.inhrelid) @@ -222,7 +226,7 @@ WHERE AND pg_constraint.contype = 'f' AND pg_constraint.conislocal; --- name: GetFunctions :many +-- name: GetProcs :many SELECT pg_proc.oid, pg_proc.proname::TEXT AS func_name, @@ -238,12 +242,12 @@ INNER JOIN ON pg_proc.pronamespace = proc_namespace.oid INNER JOIN pg_catalog.pg_language AS proc_lang - ON proc_lang.oid = pg_proc.prolang + ON pg_proc.prolang = proc_lang.oid WHERE proc_namespace.nspname NOT IN ('pg_catalog', 'information_schema') AND proc_namespace.nspname !~ '^pg_toast' AND proc_namespace.nspname !~ '^pg_temp' - AND pg_proc.prokind = 'f' + AND pg_proc.prokind = $1 -- Exclude functions belonging to extensions AND NOT EXISTS ( SELECT depend.objid @@ -358,7 +362,7 @@ SELECT FROM pg_catalog.pg_namespace AS 
extension_namespace INNER JOIN pg_catalog.pg_extension AS ext - ON ext.extnamespace = extension_namespace.oid + ON extension_namespace.oid = ext.extnamespace WHERE extension_namespace.nspname NOT IN ('pg_catalog', 'information_schema') AND extension_namespace.nspname !~ '^pg_toast' @@ -370,14 +374,18 @@ SELECT pg_type.typname::TEXT AS enum_name, type_namespace.nspname::TEXT AS enum_schema_name, ( - SELECT ARRAY_AGG(pg_enum.enumlabel ORDER BY pg_enum.enumsortorder) + SELECT + ARRAY_AGG( + pg_enum.enumlabel + ORDER BY pg_enum.enumsortorder + ) FROM pg_catalog.pg_enum WHERE pg_enum.enumtypid = pg_type.oid )::TEXT [] AS enum_labels FROM pg_catalog.pg_type AS pg_type INNER JOIN pg_catalog.pg_namespace AS type_namespace - ON type_namespace.oid = pg_type.typnamespace + ON pg_type.typnamespace = type_namespace.oid WHERE pg_type.typtype = 'e' AND type_namespace.nspname NOT IN ('pg_catalog', 'information_schema') @@ -414,7 +422,7 @@ SELECT table_namespace.nspname::TEXT AS owning_table_schema_name, pol.polpermissive AS is_permissive, ( - SELECT ARRAY_AGG(rolname) + SELECT ARRAY_AGG(roles.rolname) FROM roles WHERE roles.oid = ANY(pol.polroles) )::TEXT [] AS applies_to, diff --git a/internal/queries/queries.sql.go b/internal/queries/queries.sql.go index 746a38f..e63f338 100644 --- a/internal/queries/queries.sql.go +++ b/internal/queries/queries.sql.go @@ -133,16 +133,16 @@ SELECT FROM pg_catalog.pg_attribute AS a LEFT JOIN pg_catalog.pg_attrdef AS d - ON (d.adrelid = a.attrelid AND d.adnum = a.attnum) -LEFT JOIN pg_catalog.pg_collation AS coll ON coll.oid = a.attcollation + ON (a.attrelid = d.adrelid AND a.attnum = d.adnum) +LEFT JOIN pg_catalog.pg_collation AS coll ON a.attcollation = coll.oid LEFT JOIN pg_catalog.pg_namespace AS collation_namespace - ON collation_namespace.oid = coll.collnamespace + ON coll.collnamespace = collation_namespace.oid LEFT JOIN identity_col_seq ON - identity_col_seq.owner_relid = a.attrelid - AND identity_col_seq.owner_attnum = a.attnum + a.attrelid = identity_col_seq.owner_relid + AND a.attnum = identity_col_seq.owner_attnum WHERE a.attrelid = $1 AND a.attnum > 0 @@ -265,14 +265,18 @@ SELECT pg_type.typname::TEXT AS enum_name, type_namespace.nspname::TEXT AS enum_schema_name, ( - SELECT ARRAY_AGG(pg_enum.enumlabel ORDER BY pg_enum.enumsortorder) + SELECT + ARRAY_AGG( + pg_enum.enumlabel + ORDER BY pg_enum.enumsortorder + ) FROM pg_catalog.pg_enum WHERE pg_enum.enumtypid = pg_type.oid )::TEXT [] AS enum_labels FROM pg_catalog.pg_type AS pg_type INNER JOIN pg_catalog.pg_namespace AS type_namespace - ON type_namespace.oid = pg_type.typnamespace + ON pg_type.typnamespace = type_namespace.oid WHERE pg_type.typtype = 'e' AND type_namespace.nspname NOT IN ('pg_catalog', 'information_schema') @@ -327,7 +331,7 @@ SELECT FROM pg_catalog.pg_namespace AS extension_namespace INNER JOIN pg_catalog.pg_extension AS ext - ON ext.extnamespace = extension_namespace.oid + ON extension_namespace.oid = ext.extnamespace WHERE extension_namespace.nspname NOT IN ('pg_catalog', 'information_schema') AND extension_namespace.nspname !~ '^pg_toast' @@ -439,81 +443,9 @@ func (q *Queries) GetForeignKeyConstraints(ctx context.Context) ([]GetForeignKey return items, nil } -const getFunctions = `-- name: GetFunctions :many -SELECT - pg_proc.oid, - pg_proc.proname::TEXT AS func_name, - proc_namespace.nspname::TEXT AS func_schema_name, - proc_lang.lanname::TEXT AS func_lang, - pg_catalog.pg_get_function_identity_arguments( - pg_proc.oid - ) AS func_identity_arguments, - 
pg_catalog.pg_get_functiondef(pg_proc.oid) AS func_def -FROM pg_catalog.pg_proc -INNER JOIN - pg_catalog.pg_namespace AS proc_namespace - ON pg_proc.pronamespace = proc_namespace.oid -INNER JOIN - pg_catalog.pg_language AS proc_lang - ON proc_lang.oid = pg_proc.prolang -WHERE - proc_namespace.nspname NOT IN ('pg_catalog', 'information_schema') - AND proc_namespace.nspname !~ '^pg_toast' - AND proc_namespace.nspname !~ '^pg_temp' - AND pg_proc.prokind = 'f' - -- Exclude functions belonging to extensions - AND NOT EXISTS ( - SELECT depend.objid - FROM pg_catalog.pg_depend AS depend - WHERE - depend.classid = 'pg_proc'::REGCLASS - AND depend.objid = pg_proc.oid - AND depend.deptype = 'e' - ) -` - -type GetFunctionsRow struct { - Oid interface{} - FuncName string - FuncSchemaName string - FuncLang string - FuncIdentityArguments string - FuncDef string -} - -func (q *Queries) GetFunctions(ctx context.Context) ([]GetFunctionsRow, error) { - rows, err := q.db.QueryContext(ctx, getFunctions) - if err != nil { - return nil, err - } - defer rows.Close() - var items []GetFunctionsRow - for rows.Next() { - var i GetFunctionsRow - if err := rows.Scan( - &i.Oid, - &i.FuncName, - &i.FuncSchemaName, - &i.FuncLang, - &i.FuncIdentityArguments, - &i.FuncDef, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Close(); err != nil { - return nil, err - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - const getIndexes = `-- name: GetIndexes :many SELECT - c.oid AS oid, + c.oid, c.relname::TEXT AS index_name, table_c.relname::TEXT AS table_name, table_namespace.nspname::TEXT AS table_schema_name, @@ -529,21 +461,25 @@ SELECT COALESCE(parent_c.relname, '')::TEXT AS parent_index_name, COALESCE(parent_namespace.nspname, '')::TEXT AS parent_index_schema_name, ( - SELECT ARRAY_AGG(att.attname ORDER BY indkey_ord.ord) + SELECT + ARRAY_AGG( + att.attname + ORDER BY indkey_ord.ord + ) FROM UNNEST(i.indkey) WITH ORDINALITY AS indkey_ord (attnum, ord) INNER JOIN pg_catalog.pg_attribute AS att - ON att.attrelid = table_c.oid AND att.attnum = indkey_ord.attnum + ON att.attrelid = table_c.oid AND indkey_ord.attnum = att.attnum )::TEXT [] AS column_names, COALESCE(con.conislocal, false) AS constraint_is_local FROM pg_catalog.pg_class AS c -INNER JOIN pg_catalog.pg_index AS i ON (i.indexrelid = c.oid) -INNER JOIN pg_catalog.pg_class AS table_c ON (table_c.oid = i.indrelid) +INNER JOIN pg_catalog.pg_index AS i ON (c.oid = i.indexrelid) +INNER JOIN pg_catalog.pg_class AS table_c ON (i.indrelid = table_c.oid) INNER JOIN pg_catalog.pg_namespace AS table_namespace ON table_c.relnamespace = table_namespace.oid LEFT JOIN pg_catalog.pg_constraint AS con - ON (con.conindid = c.oid AND con.contype IN ('p', 'u', null)) + ON (c.oid = con.conindid AND con.contype IN ('p', 'u', null)) LEFT JOIN pg_catalog.pg_inherits AS idx_inherits ON (c.oid = idx_inherits.inhrelid) @@ -637,7 +573,7 @@ SELECT table_namespace.nspname::TEXT AS owning_table_schema_name, pol.polpermissive AS is_permissive, ( - SELECT ARRAY_AGG(rolname) + SELECT ARRAY_AGG(roles.rolname) FROM roles WHERE roles.oid = ANY(pol.polroles) )::TEXT [] AS applies_to, @@ -715,6 +651,78 @@ func (q *Queries) GetPolicies(ctx context.Context) ([]GetPoliciesRow, error) { return items, nil } +const getProcs = `-- name: GetProcs :many +SELECT + pg_proc.oid, + pg_proc.proname::TEXT AS func_name, + proc_namespace.nspname::TEXT AS func_schema_name, + proc_lang.lanname::TEXT AS func_lang, + 
pg_catalog.pg_get_function_identity_arguments( + pg_proc.oid + ) AS func_identity_arguments, + pg_catalog.pg_get_functiondef(pg_proc.oid) AS func_def +FROM pg_catalog.pg_proc +INNER JOIN + pg_catalog.pg_namespace AS proc_namespace + ON pg_proc.pronamespace = proc_namespace.oid +INNER JOIN + pg_catalog.pg_language AS proc_lang + ON pg_proc.prolang = proc_lang.oid +WHERE + proc_namespace.nspname NOT IN ('pg_catalog', 'information_schema') + AND proc_namespace.nspname !~ '^pg_toast' + AND proc_namespace.nspname !~ '^pg_temp' + AND pg_proc.prokind = $1 + -- Exclude functions belonging to extensions + AND NOT EXISTS ( + SELECT depend.objid + FROM pg_catalog.pg_depend AS depend + WHERE + depend.classid = 'pg_proc'::REGCLASS + AND depend.objid = pg_proc.oid + AND depend.deptype = 'e' + ) +` + +type GetProcsRow struct { + Oid interface{} + FuncName string + FuncSchemaName string + FuncLang string + FuncIdentityArguments string + FuncDef string +} + +func (q *Queries) GetProcs(ctx context.Context, prokind interface{}) ([]GetProcsRow, error) { + rows, err := q.db.QueryContext(ctx, getProcs, prokind) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetProcsRow + for rows.Next() { + var i GetProcsRow + if err := rows.Scan( + &i.Oid, + &i.FuncName, + &i.FuncSchemaName, + &i.FuncLang, + &i.FuncIdentityArguments, + &i.FuncDef, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getSchemas = `-- name: GetSchemas :many SELECT nspname::TEXT AS schema_name FROM pg_catalog.pg_namespace @@ -858,7 +866,7 @@ func (q *Queries) GetSequences(ctx context.Context) ([]GetSequencesRow, error) { const getTables = `-- name: GetTables :many SELECT - c.oid AS oid, + c.oid, c.relname::TEXT AS table_name, table_namespace.nspname::TEXT AS table_schema_name, c.relreplident::TEXT AS replica_identity, @@ -881,7 +889,7 @@ INNER JOIN ON c.relnamespace = table_namespace.oid LEFT JOIN pg_catalog.pg_inherits AS table_inherits - ON table_inherits.inhrelid = c.oid + ON c.oid = table_inherits.inhrelid LEFT JOIN pg_catalog.pg_class AS parent_c ON table_inherits.inhparent = parent_c.oid diff --git a/internal/schema/schema.go b/internal/schema/schema.go index 68d66d8..4849cba 100644 --- a/internal/schema/schema.go +++ b/internal/schema/schema.go @@ -57,11 +57,12 @@ type Schema struct { ForeignKeyConstraints []ForeignKeyConstraint Sequences []Sequence Functions []Function + Procedures []Procedure Triggers []Trigger } -// Normalize normalizes the schema (alphabetically sorts tables and columns in tables) -// Useful for hashing and testing +// Normalize normalizes the schema (alphabetically sorts tables and columns in tables). +// Useful for hashing and testing. func (s Schema) Normalize() Schema { s.NamedSchemas = sortSchemaObjectsByName(s.NamedSchemas) s.Extensions = sortSchemaObjectsByName(s.Extensions) @@ -84,6 +85,7 @@ func (s Schema) Normalize() Schema { } s.Functions = normFunctions + s.Procedures = sortSchemaObjectsByName(s.Procedures) s.Triggers = sortSchemaObjectsByName(s.Triggers) return s @@ -372,6 +374,14 @@ type Function struct { DependsOnFunctions []SchemaQualifiedName } +type Procedure struct { + SchemaQualifiedName + // Def is the statement required to completely (re)create + // the procedure, as returned by `pg_get_functiondef`. It is a CREATE OR REPLACE + // statement. 
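+	// For example (shape only, matching the fetched definitions asserted in schema_test.go):
+	//   CREATE OR REPLACE PROCEDURE public.some_plpgsql_procedure(IN foobar numeric) LANGUAGE plpgsql AS $procedure$ ... $procedure$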
+ Def string +} + var ( // The first matching group is the "CREATE ". The second matching group is the rest of the statement triggerToOrReplaceRegex = regexp.MustCompile("^(CREATE )(.*)$") @@ -621,6 +631,13 @@ func (s *schemaFetcher) getSchema(ctx context.Context) (Schema, error) { return Schema{}, fmt.Errorf("starting functions future: %w", err) } + proceduresFuture, err := concurrent.SubmitFuture(ctx, goroutineRunner, func() ([]Procedure, error) { + return s.fetchProcedures(ctx) + }) + if err != nil { + return Schema{}, fmt.Errorf("starting functions future: %w", err) + } + triggersFuture, err := concurrent.SubmitFuture(ctx, goroutineRunner, func() ([]Trigger, error) { return s.fetchTriggers(ctx) }) @@ -668,6 +685,11 @@ func (s *schemaFetcher) getSchema(ctx context.Context) (Schema, error) { return Schema{}, fmt.Errorf("getting functions: %w", err) } + procedures, err := proceduresFuture.Get(ctx) + if err != nil { + return Schema{}, fmt.Errorf("getting procedures: %w", err) + } + triggers, err := triggersFuture.Get(ctx) if err != nil { return Schema{}, fmt.Errorf("getting triggers: %w", err) @@ -682,6 +704,7 @@ func (s *schemaFetcher) getSchema(ctx context.Context) (Schema, error) { ForeignKeyConstraints: fkCons, Sequences: sequences, Functions: functions, + Procedures: procedures, Triggers: triggers, }, nil } @@ -1109,9 +1132,9 @@ func (s *schemaFetcher) fetchSequences(ctx context.Context) ([]Sequence, error) } func (s *schemaFetcher) fetchFunctions(ctx context.Context) ([]Function, error) { - rawFunctions, err := s.q.GetFunctions(ctx) + rawFunctions, err := s.q.GetProcs(ctx, 'f') if err != nil { - return nil, fmt.Errorf("GetFunctions: %w", err) + return nil, fmt.Errorf("GetProcs: %w", err) } goroutineRunner := s.goroutineRunnerFactory() @@ -1143,14 +1166,14 @@ func (s *schemaFetcher) fetchFunctions(ctx context.Context) ([]Function, error) return functions, nil } -func (s *schemaFetcher) buildFunction(ctx context.Context, rawFunction queries.GetFunctionsRow) (Function, error) { +func (s *schemaFetcher) buildFunction(ctx context.Context, rawFunction queries.GetProcsRow) (Function, error) { dependsOnFunctions, err := s.fetchDependsOnFunctions(ctx, "pg_proc", rawFunction.Oid) if err != nil { return Function{}, fmt.Errorf("fetchDependsOnFunctions(%s): %w", rawFunction.Oid, err) } return Function{ - SchemaQualifiedName: buildFuncName(rawFunction.FuncName, rawFunction.FuncIdentityArguments, rawFunction.FuncSchemaName), + SchemaQualifiedName: buildProcName(rawFunction.FuncName, rawFunction.FuncIdentityArguments, rawFunction.FuncSchemaName), FunctionDef: rawFunction.FuncDef, Language: rawFunction.FuncLang, DependsOnFunctions: dependsOnFunctions, @@ -1168,12 +1191,38 @@ func (s *schemaFetcher) fetchDependsOnFunctions(ctx context.Context, systemCatal var functionNames []SchemaQualifiedName for _, rawFunction := range dependsOnFunctions { - functionNames = append(functionNames, buildFuncName(rawFunction.FuncName, rawFunction.FuncIdentityArguments, rawFunction.FuncSchemaName)) + functionNames = append(functionNames, buildProcName(rawFunction.FuncName, rawFunction.FuncIdentityArguments, rawFunction.FuncSchemaName)) } return functionNames, nil } +func (s *schemaFetcher) fetchProcedures(ctx context.Context) ([]Procedure, error) { + rawProcedures, err := s.q.GetProcs(ctx, 'p') + if err != nil { + return nil, fmt.Errorf("GetProcs: %w", err) + } + + var procedures []Procedure + for _, rawProcedure := range rawProcedures { + p := Procedure{ + SchemaQualifiedName: buildProcName(rawProcedure.FuncName, 
rawProcedure.FuncIdentityArguments, rawProcedure.FuncSchemaName), + Def: rawProcedure.FuncDef, + } + procedures = append(procedures, p) + } + + procedures = filterSliceByName( + procedures, + func(function Procedure) SchemaQualifiedName { + return function.SchemaQualifiedName + }, + s.nameFilter, + ) + + return procedures, nil +} + type policyAndTable struct { policy Policy table SchemaQualifiedName @@ -1226,7 +1275,7 @@ func (s *schemaFetcher) fetchTriggers(ctx context.Context) ([]Trigger, error) { triggers = append(triggers, Trigger{ EscapedName: EscapeIdentifier(rawTrigger.TriggerName), OwningTable: buildNameFromUnescaped(rawTrigger.OwningTableName, rawTrigger.OwningTableSchemaName), - Function: buildFuncName(rawTrigger.FuncName, rawTrigger.FuncIdentityArguments, rawTrigger.FuncSchemaName), + Function: buildProcName(rawTrigger.FuncName, rawTrigger.FuncIdentityArguments, rawTrigger.FuncSchemaName), GetTriggerDefStmt: GetTriggerDefStatement(rawTrigger.TriggerDef), }) } @@ -1245,7 +1294,9 @@ func (s *schemaFetcher) fetchTriggers(ctx context.Context) ([]Trigger, error) { return triggers, nil } -func buildFuncName(name, identityArguments, schemaName string) SchemaQualifiedName { +// buildProcName is used to build the schema qualified name for a proc (function, procedure), i.e., anything +// identified by a name AND its arguments. +func buildProcName(name, identityArguments, schemaName string) SchemaQualifiedName { return SchemaQualifiedName{ SchemaName: schemaName, EscapedName: fmt.Sprintf("\"%s\"(%s)", name, identityArguments), diff --git a/internal/schema/schema_test.go b/internal/schema/schema_test.go index 4d0b077..7f7eb0b 100644 --- a/internal/schema/schema_test.go +++ b/internal/schema/schema_test.go @@ -142,6 +142,20 @@ var ( -- Reference a function in a filtered out schema. The trigger should still be included. 
EXECUTE PROCEDURE schema_filtered_1.increment_version(); + CREATE PROCEDURE schema_2.some_insert_procedure(a INTEGER, b INTEGER) + LANGUAGE SQL + BEGIN ATOMIC + INSERT INTO schema_2.foo DEFAULT VALUES; + END; + + CREATE PROCEDURE some_plpgsql_procedure(foobar NUMERIC) + LANGUAGE plpgsql + AS $$ + BEGIN + RAISE NOTICE 'some notice'; + END + $$; + -- Create table with conflicting name that has check constraints CREATE TABLE schema_1.foo( id INT NOT NULL, @@ -177,6 +191,12 @@ var ( ON UPDATE CASCADE ON DELETE CASCADE NOT VALID; + -- Validate procedures are filtered out + CREATE PROCEDURE schema_filtered_1.some_filtered_procedure(a INTEGER) + LANGUAGE SQL + BEGIN ATOMIC + INSERT INTO schema_2.foo DEFAULT VALUES; + END; -- Validate triggers are filtered out CREATE TRIGGER some_trigger BEFORE UPDATE ON schema_filtered_1.foo_fk @@ -191,7 +211,7 @@ var ( TO PUBLIC USING (version > 0); `}, - expectedHash: "4f6a01ac1a078624", + expectedHash: "500097cd4fa6f068", expectedSchema: Schema{ NamedSchemas: []NamedSchema{ {Name: "public"}, @@ -420,6 +440,16 @@ var ( Language: "plpgsql", }, }, + Procedures: []Procedure{ + { + SchemaQualifiedName: SchemaQualifiedName{SchemaName: "public", EscapedName: "\"some_plpgsql_procedure\"(IN foobar numeric)"}, + Def: "CREATE OR REPLACE PROCEDURE public.some_plpgsql_procedure(IN foobar numeric)\n LANGUAGE plpgsql\nAS $procedure$\n\t\t\t\tBEGIN\n\t\t\t\t\tRAISE NOTICE 'some notice';\n\t\t\t\tEND\n\t\t\t\t$procedure$\n", + }, + { + SchemaQualifiedName: SchemaQualifiedName{SchemaName: "schema_2", EscapedName: "\"some_insert_procedure\"(IN a integer, IN b integer)"}, + Def: "CREATE OR REPLACE PROCEDURE schema_2.some_insert_procedure(IN a integer, IN b integer)\n LANGUAGE sql\nBEGIN ATOMIC\n INSERT INTO schema_2.foo DEFAULT VALUES;\nEND\n", + }, + }, Triggers: []Trigger{ { EscapedName: "\"some_trigger\"", @@ -494,7 +524,7 @@ var ( ALTER TABLE foo_fk_1 ADD CONSTRAINT foo_fk_1_fk FOREIGN KEY (author, content) REFERENCES foo_1 (author, content) NOT VALID; `}, - expectedHash: "14fc890b05a1fa7b", + expectedHash: "cf473d75363e9f77", expectedSchema: Schema{ NamedSchemas: []NamedSchema{ {Name: "public"}, @@ -1056,7 +1086,7 @@ var ( CREATE TYPE pg_temp.color AS ENUM ('red', 'green', 'blue'); `}, // Assert empty schema hash, since we want to validate specifically that this hash is deterministic - expectedHash: "e63f48c273376e85", + expectedHash: "83bae9b012ee367e", expectedSchema: Schema{ NamedSchemas: []NamedSchema{ {Name: "public"}, diff --git a/pkg/diff/function_sql_vertex_generator.go b/pkg/diff/function_sql_vertex_generator.go new file mode 100644 index 0000000..088ff6a --- /dev/null +++ b/pkg/diff/function_sql_vertex_generator.go @@ -0,0 +1,108 @@ +package diff + +import ( + "fmt" + + "github.com/google/go-cmp/cmp" + "github.com/stripe/pg-schema-diff/internal/schema" +) + +type functionSQLVertexGenerator struct { + // functionsInNewSchemaByName is a map of function name to functions in the new schema. 
+ // These functions are not necessarily new + functionsInNewSchemaByName map[string]schema.Function +} + +func newFunctionSqlVertexGenerator(functionsInNewSchemaByName map[string]schema.Function) sqlVertexGenerator[schema.Function, functionDiff] { + return legacyToNewSqlVertexGenerator[schema.Function, functionDiff](&functionSQLVertexGenerator{ + functionsInNewSchemaByName: functionsInNewSchemaByName, + }) +} + +func (f *functionSQLVertexGenerator) Add(function schema.Function) ([]Statement, error) { + var hazards []MigrationHazard + if !canFunctionDependenciesBeTracked(function) { + hazards = append(hazards, MigrationHazard{ + Type: MigrationHazardTypeHasUntrackableDependencies, + Message: "Dependencies, i.e. other functions used in the function body, of non-sql functions cannot be tracked. " + + "As a result, we cannot guarantee that function dependencies are ordered properly relative to this " + + "statement. For adds, this means you need to ensure that all functions this function depends on are " + + "created/altered before this statement.", + }) + } + return []Statement{{ + DDL: function.FunctionDef, + Timeout: statementTimeoutDefault, + LockTimeout: lockTimeoutDefault, + Hazards: hazards, + }}, nil +} + +func (f *functionSQLVertexGenerator) Delete(function schema.Function) ([]Statement, error) { + var hazards []MigrationHazard + if !canFunctionDependenciesBeTracked(function) { + hazards = append(hazards, MigrationHazard{ + Type: MigrationHazardTypeHasUntrackableDependencies, + Message: "Dependencies, i.e. other functions used in the function body, of non-sql functions cannot be " + + "tracked. As a result, we cannot guarantee that function dependencies are ordered properly relative to " + + "this statement. For drops, this means you need to ensure that all functions this function depends on " + + "are dropped after this statement.", + }) + } + return []Statement{{ + DDL: fmt.Sprintf("DROP FUNCTION %s", function.GetFQEscapedName()), + Timeout: statementTimeoutDefault, + LockTimeout: lockTimeoutDefault, + Hazards: hazards, + }}, nil +} + +func (f *functionSQLVertexGenerator) Alter(diff functionDiff) ([]Statement, error) { + // We are assuming the function has been normalized, i.e., we don't have to worry DependsOnFunctions ordering + // causing a false positive diff detected. + if cmp.Equal(diff.old, diff.new) { + return nil, nil + } + return f.Add(diff.new) +} + +func canFunctionDependenciesBeTracked(function schema.Function) bool { + return function.Language == "sql" +} + +func (f *functionSQLVertexGenerator) GetSQLVertexId(function schema.Function, diffType diffType) sqlVertexId { + return buildFunctionVertexId(function.SchemaQualifiedName, diffType) +} + +func buildFunctionVertexId(name schema.SchemaQualifiedName, diffType diffType) sqlVertexId { + return buildSchemaObjVertexId("function", name.GetFQEscapedName(), diffType) +} + +func (f *functionSQLVertexGenerator) GetAddAlterDependencies(newFunction, oldFunction schema.Function) ([]dependency, error) { + // Since functions can just be `CREATE OR REPLACE`, there will never be a case where a function is + // added and dropped in the same migration. 
Thus, we don't need a dependency on the delete vertex of a function + // because there won't be one if it is being added/altered + var deps []dependency + for _, depFunction := range newFunction.DependsOnFunctions { + deps = append(deps, mustRun(f.GetSQLVertexId(newFunction, diffTypeAddAlter)).after(buildFunctionVertexId(depFunction, diffTypeAddAlter))) + } + + if !cmp.Equal(oldFunction, schema.Function{}) { + // If the function is being altered: + // If the old version of the function calls other functions that are being deleted come, those deletions + // must come after the function is altered, so it is no longer dependent on those dropped functions + for _, depFunction := range oldFunction.DependsOnFunctions { + deps = append(deps, mustRun(f.GetSQLVertexId(newFunction, diffTypeAddAlter)).before(buildFunctionVertexId(depFunction, diffTypeDelete))) + } + } + + return deps, nil +} + +func (f *functionSQLVertexGenerator) GetDeleteDependencies(function schema.Function) ([]dependency, error) { + var deps []dependency + for _, depFunction := range function.DependsOnFunctions { + deps = append(deps, mustRun(f.GetSQLVertexId(function, diffTypeDelete)).before(buildFunctionVertexId(depFunction, diffTypeDelete))) + } + return deps, nil +} diff --git a/pkg/diff/procedure_sql_vertex_generator.go b/pkg/diff/procedure_sql_vertex_generator.go new file mode 100644 index 0000000..d9f725b --- /dev/null +++ b/pkg/diff/procedure_sql_vertex_generator.go @@ -0,0 +1,117 @@ +package diff + +import ( + "fmt" + + "github.com/google/go-cmp/cmp" + "github.com/stripe/pg-schema-diff/internal/schema" +) + +type procedureSQLVertexGenerator struct { + newSchema schema.Schema +} + +func newProcedureSqlVertexGenerator(newSchema schema.Schema) sqlVertexGenerator[schema.Procedure, procedureDiff] { + return &procedureSQLVertexGenerator{ + newSchema: newSchema, + } +} + +func (p procedureSQLVertexGenerator) Add(s schema.Procedure) (partialSQLGraph, error) { + // Procedures can't be added until all dependencies have been added. Weirdly, Postgres ONLY enforces these + // dependencies at creation time and not after...so we will make a best effort to order this statement after + // all other dependencies that procedures might depend on. + + var deps []dependency + + // Run after all tables have been added/altered, since a procedure might query a table. + for _, t := range p.newSchema.Tables { + deps = append(deps, mustRun(buildProcedureVertexId(s.SchemaQualifiedName, diffTypeAddAlter)).after(buildTableVertexId(t.SchemaQualifiedName, diffTypeAddAlter))) + } + + // Run after all functions, since a procedure might call a function. + for _, f := range p.newSchema.Functions { + deps = append(deps, mustRun(buildProcedureVertexId(s.SchemaQualifiedName, diffTypeAddAlter)).after(buildFunctionVertexId(f.SchemaQualifiedName, diffTypeAddAlter))) + } + + // Run after all sequences, since a procedure might call a sequence. 
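+	// (e.g., a procedure body that uses NEXTVAL on a sequence, as exercised in the acceptance tests).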
+ for _, seq := range p.newSchema.Sequences { + deps = append(deps, mustRun(buildProcedureVertexId(s.SchemaQualifiedName, diffTypeAddAlter)).after(buildSequenceVertexId(seq.SchemaQualifiedName, diffTypeAddAlter))) + } + + return partialSQLGraph{ + vertices: []sqlVertex{{ + id: buildProcedureVertexId(s.SchemaQualifiedName, diffTypeAddAlter), + priority: sqlPrioritySooner, + statements: []Statement{{ + DDL: s.Def, + Timeout: statementTimeoutDefault, + LockTimeout: lockTimeoutDefault, + Hazards: []MigrationHazard{{ + Type: MigrationHazardTypeHasUntrackableDependencies, + Message: "Dependencies of procedures are not tracked by Postgres. " + + "As a result, we cannot guarantee that this procedure's dependencies are ordered properly relative to " + + "this statement. For adds, this means you need to ensure that all objects this function depends on " + + "are added before this statement.", + }}, + }}, + }}, + dependencies: deps, + }, nil +} + +func (p procedureSQLVertexGenerator) Delete(s schema.Procedure) (partialSQLGraph, error) { + // Stored procedure dependencies can't be tracked...so they can either be deleted earlier or later. We will + // delete earlier, since a procedure is more likely to depend on objects that being depended on. Thus, we will have + // a stored procedure drop before other objects that might depend on it. + var deps []dependency + + // Run before all tables have been added/altered, since a procedure might query a table. This does not work for columns + // being dropped because column drops are not "trackable" from external SQL generators until + // https://github.com/stripe/pg-schema-diff/issues/131 is fully implemented. + for _, t := range p.newSchema.Tables { + deps = append(deps, mustRun(buildProcedureVertexId(s.SchemaQualifiedName, diffTypeDelete)).after(buildTableVertexId(t.SchemaQualifiedName, diffTypeAddAlter))) + } + + // Run before all functions, since a procedure might call a function. + for _, f := range p.newSchema.Functions { + deps = append(deps, mustRun(buildProcedureVertexId(s.SchemaQualifiedName, diffTypeDelete)).after(buildFunctionVertexId(f.SchemaQualifiedName, diffTypeAddAlter))) + } + + // Run before all sequences, since a procedure might call a sequence. + for _, seq := range p.newSchema.Sequences { + deps = append(deps, mustRun(buildProcedureVertexId(s.SchemaQualifiedName, diffTypeDelete)).after(buildSequenceVertexId(seq.SchemaQualifiedName, diffTypeAddAlter))) + } + + return partialSQLGraph{ + vertices: []sqlVertex{{ + id: buildProcedureVertexId(s.SchemaQualifiedName, diffTypeDelete), + priority: sqlPriorityLater, + statements: []Statement{{ + DDL: fmt.Sprintf("DROP PROCEDURE %s", s.GetFQEscapedName()), + Timeout: statementTimeoutDefault, + LockTimeout: lockTimeoutDefault, + Hazards: []MigrationHazard{{ + Type: MigrationHazardTypeHasUntrackableDependencies, + Message: "Dependencies of procedures are not tracked by Postgres. " + + "As a result, we cannot guarantee that this procedure's dependencies are ordered properly relative to " + + "this statement. For drops, this means you need to ensure that all objects this function depends on " + + "are dropped after this statement.", + }}, + }}, + }}, + dependencies: deps, + }, nil +} + +func (p procedureSQLVertexGenerator) Alter(d procedureDiff) (partialSQLGraph, error) { + if cmp.Equal(d.old, d.new) { + return partialSQLGraph{}, nil + } + // New adds or replaces the procedure. 
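+	// Add is sufficient here because Def is a CREATE OR REPLACE statement, so re-running it replaces the existing
+	// procedure in place rather than requiring an explicit drop first.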
+ return p.Add(d.new) +} + +func buildProcedureVertexId(name schema.SchemaQualifiedName, diffType diffType) sqlVertexId { + return buildSchemaObjVertexId("procedure", name.GetFQEscapedName(), diffType) +} diff --git a/pkg/diff/sql_generator.go b/pkg/diff/sql_generator.go index 7919d7d..caf9aec 100644 --- a/pkg/diff/sql_generator.go +++ b/pkg/diff/sql_generator.go @@ -40,13 +40,6 @@ var ( // It is recommended to ignore column ordering changes to column order ErrColumnOrderingChanged = fmt.Errorf("column ordering changed: %w", ErrNotImplemented) - migrationHazardAddAlterFunctionCannotTrackDependencies = MigrationHazard{ - Type: MigrationHazardTypeHasUntrackableDependencies, - Message: "Dependencies, i.e. other functions used in the function body, of non-sql functions cannot be tracked. " + - "As a result, we cannot guarantee that function dependencies are ordered properly relative to this " + - "statement. For adds, this means you need to ensure that all functions this function depends on are " + - "created/altered before this statement.", - } migrationHazardIndexDroppedQueryPerf = MigrationHazard{ Type: MigrationHazardTypeIndexDropped, Message: "Dropping this index means queries that use this index might perform worse because " + @@ -129,6 +122,10 @@ type ( oldAndNew[schema.Function] } + procedureDiff struct { + oldAndNew[schema.Procedure] + } + triggerDiff struct { oldAndNew[schema.Trigger] } @@ -144,6 +141,7 @@ type schemaDiff struct { foreignKeyConstraintDiffs listDiff[schema.ForeignKeyConstraint, foreignKeyConstraintDiff] sequenceDiffs listDiff[schema.Sequence, sequenceDiff] functionDiffs listDiff[schema.Function, functionDiff] + proceduresDiffs listDiff[schema.Procedure, procedureDiff] triggerDiffs listDiff[schema.Trigger, triggerDiff] } @@ -288,6 +286,18 @@ func buildSchemaDiff(old, new schema.Schema) (schemaDiff, bool, error) { return schemaDiff{}, false, fmt.Errorf("diffing functions: %w", err) } + procedureDiffs, err := diffLists(old.Procedures, new.Procedures, func(old, new schema.Procedure, _, _ int) (procedureDiff, bool, error) { + return procedureDiff{ + oldAndNew[schema.Procedure]{ + old: old, + new: new, + }, + }, false, nil + }) + if err != nil { + return schemaDiff{}, false, fmt.Errorf("diffing procedures: %w", err) + } + triggerDiffs, err := diffLists(old.Triggers, new.Triggers, func(old, new schema.Trigger, _, _ int) (triggerDiff, bool, error) { if _, isOnNewTable := addedTablesByName[new.OwningTable.GetName()]; isOnNewTable { // If the table is new, then it must be re-created (this occurs if the base table has been @@ -318,6 +328,7 @@ func buildSchemaDiff(old, new schema.Schema) (schemaDiff, bool, error) { foreignKeyConstraintDiffs: foreignKeyConstraintDiffs, sequenceDiffs: sequencesDiffs, functionDiffs: functionDiffs, + proceduresDiffs: procedureDiffs, triggerDiffs: triggerDiffs, }, false, nil } @@ -499,6 +510,7 @@ func (schemaSQLGenerator) Alter(diff schemaDiff) ([]Statement, error) { tablesInNewSchemaByName := buildSchemaObjByNameMap(diff.new.Tables) deletedTablesByName := buildSchemaObjByNameMap(diff.tableDiffs.deletes) addedTablesByName := buildSchemaObjByNameMap(diff.tableDiffs.adds) + functionsInNewSchemaByName := buildSchemaObjByNameMap(diff.new.Functions) namedSchemaStatements, err := diff.namedSchemaDiffs.resolveToSQLGroupedByEffect(&namedSchemaSQLGenerator{}) if err != nil { @@ -580,16 +592,20 @@ func (schemaSQLGenerator) Alter(diff schemaDiff) ([]Statement, error) { } partialGraph = concatPartialGraphs(partialGraph, sequenceOwnershipsPartialGraph) - 
functionsInNewSchemaByName := buildSchemaObjByNameMap(diff.new.Functions) - functionGenerator := legacyToNewSqlVertexGenerator[schema.Function, functionDiff](&functionSQLVertexGenerator{ - functionsInNewSchemaByName: functionsInNewSchemaByName, - }) + functionGenerator := newFunctionSqlVertexGenerator(functionsInNewSchemaByName) functionsPartialGraph, err := generatePartialGraph(functionGenerator, diff.functionDiffs) if err != nil { return nil, fmt.Errorf("resolving function diff: %w", err) } partialGraph = concatPartialGraphs(partialGraph, functionsPartialGraph) + procedureGenerator := newProcedureSqlVertexGenerator(diff.new) + proceduresPartialGraph, err := generatePartialGraph(procedureGenerator, diff.proceduresDiffs) + if err != nil { + return nil, fmt.Errorf("resolving procedure diff: %w", err) + } + partialGraph = concatPartialGraphs(partialGraph, proceduresPartialGraph) + triggerGenerator := legacyToNewSqlVertexGenerator[schema.Trigger, triggerDiff](&triggerSQLVertexGenerator{ functionsInNewSchemaByName: functionsInNewSchemaByName, }) @@ -2546,104 +2562,6 @@ func (e *extensionSQLGenerator) Alter(diff extensionDiff) ([]Statement, error) { return statements, nil } -type functionSQLVertexGenerator struct { - // functionsInNewSchemaByName is a map of function new to functions in the new schema. - // These functions are not necessarily new - functionsInNewSchemaByName map[string]schema.Function -} - -func (f *functionSQLVertexGenerator) Add(function schema.Function) ([]Statement, error) { - var hazards []MigrationHazard - if !canFunctionDependenciesBeTracked(function) { - hazards = append(hazards, migrationHazardAddAlterFunctionCannotTrackDependencies) - } - return []Statement{{ - DDL: function.FunctionDef, - Timeout: statementTimeoutDefault, - LockTimeout: lockTimeoutDefault, - Hazards: hazards, - }}, nil -} - -func (f *functionSQLVertexGenerator) Delete(function schema.Function) ([]Statement, error) { - var hazards []MigrationHazard - if !canFunctionDependenciesBeTracked(function) { - hazards = append(hazards, MigrationHazard{ - Type: MigrationHazardTypeHasUntrackableDependencies, - Message: "Dependencies, i.e. other functions used in the function body, of non-sql functions cannot be " + - "tracked. As a result, we cannot guarantee that function dependencies are ordered properly relative to " + - "this statement. For drops, this means you need to ensure that all functions this function depends on " + - "are dropped after this statement.", - }) - } - return []Statement{{ - DDL: fmt.Sprintf("DROP FUNCTION %s", function.GetFQEscapedName()), - Timeout: statementTimeoutDefault, - LockTimeout: lockTimeoutDefault, - Hazards: hazards, - }}, nil -} - -func (f *functionSQLVertexGenerator) Alter(diff functionDiff) ([]Statement, error) { - // We are assuming the function has been normalized, i.e., we don't have to worry DependsOnFunctions ordering - // causing a false positive diff detected. 
- if cmp.Equal(diff.old, diff.new) { - return nil, nil - } - - var hazards []MigrationHazard - if !canFunctionDependenciesBeTracked(diff.new) { - hazards = append(hazards, migrationHazardAddAlterFunctionCannotTrackDependencies) - } - return []Statement{{ - DDL: diff.new.FunctionDef, - Timeout: statementTimeoutDefault, - LockTimeout: lockTimeoutDefault, - Hazards: hazards, - }}, nil -} - -func canFunctionDependenciesBeTracked(function schema.Function) bool { - return function.Language == "sql" -} - -func (f *functionSQLVertexGenerator) GetSQLVertexId(function schema.Function, diffType diffType) sqlVertexId { - return buildFunctionVertexId(function.SchemaQualifiedName, diffType) -} - -func buildFunctionVertexId(name schema.SchemaQualifiedName, diffType diffType) sqlVertexId { - return buildSchemaObjVertexId("function", name.GetFQEscapedName(), diffType) -} - -func (f *functionSQLVertexGenerator) GetAddAlterDependencies(newFunction, oldFunction schema.Function) ([]dependency, error) { - // Since functions can just be `CREATE OR REPLACE`, there will never be a case where a function is - // added and dropped in the same migration. Thus, we don't need a dependency on the delete vertex of a function - // because there won't be one if it is being added/altered - var deps []dependency - for _, depFunction := range newFunction.DependsOnFunctions { - deps = append(deps, mustRun(f.GetSQLVertexId(newFunction, diffTypeAddAlter)).after(buildFunctionVertexId(depFunction, diffTypeAddAlter))) - } - - if !cmp.Equal(oldFunction, schema.Function{}) { - // If the function is being altered: - // If the old version of the function calls other functions that are being deleted come, those deletions - // must come after the function is altered, so it is no longer dependent on those dropped functions - for _, depFunction := range oldFunction.DependsOnFunctions { - deps = append(deps, mustRun(f.GetSQLVertexId(newFunction, diffTypeAddAlter)).before(buildFunctionVertexId(depFunction, diffTypeDelete))) - } - } - - return deps, nil -} - -func (f *functionSQLVertexGenerator) GetDeleteDependencies(function schema.Function) ([]dependency, error) { - var deps []dependency - for _, depFunction := range function.DependsOnFunctions { - deps = append(deps, mustRun(f.GetSQLVertexId(function, diffTypeDelete)).before(buildFunctionVertexId(depFunction, diffTypeDelete))) - } - return deps, nil -} - type triggerSQLVertexGenerator struct { // functionsInNewSchemaByName is a map of function new to functions in the new schema. // These functions are not necessarily new