Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions cmd/migrate_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,14 @@ func runPlanAndApplyTest(t *testing.T, ctx context.Context, container *struct {
t.Fatalf("Failed to create test database %s: %v", dbName, err)
}

// Extensions are cluster-level and persist across tests on the shared
// embedded postgres instance. Register the teardown unconditionally —
// extensions can come from setup.sql, old.sql, or new.sql (the new.sql
// path matters for the create_extension fixture, which has no setup.sql).
t.Cleanup(func() {
cleanupSharedClusterExtensions(t)
})

// STEP 0: Execute optional setup.sql (for cross-schema setup, extension types, etc.)
if _, err := os.Stat(tc.setupFile); err == nil {
setupContent, err := os.ReadFile(tc.setupFile)
Expand Down Expand Up @@ -576,3 +584,40 @@ func matchesFilter(relPath, filter string) bool {
// Fallback: check if filter is a substring of the path
return strings.Contains(relPath, filter)
}

// cleanupSharedClusterExtensions drops any extensions on the shared embedded
// postgres instance other than the built-in `plpgsql`. Extensions are
// cluster-level state and survive per-database resets, so without this teardown
// a setup.sql that installs (say) btree_gist or hstore would leak into every
// subsequent test that inspects the cluster.
func cleanupSharedClusterExtensions(t *testing.T) {
t.Helper()
if sharedEmbeddedPG == nil {
return
}
conn, _, _, _, _, _ := testutil.ConnectToPostgres(t, sharedEmbeddedPG)
defer conn.Close()

ctx := context.Background()
rows, err := conn.QueryContext(ctx, "SELECT extname FROM pg_extension WHERE extname <> 'plpgsql'")
if err != nil {
t.Logf("extension cleanup: query failed (continuing): %v", err)
return
}
var names []string
for rows.Next() {
var name string
if err := rows.Scan(&name); err != nil {
t.Logf("extension cleanup: scan failed (continuing): %v", err)
continue
}
names = append(names, name)
}
rows.Close()

for _, name := range names {
if _, err := conn.ExecContext(ctx, fmt.Sprintf("DROP EXTENSION IF EXISTS %q CASCADE", name)); err != nil {
t.Logf("extension cleanup: failed to drop %s (continuing): %v", name, err)
}
}
}
37 changes: 37 additions & 0 deletions internal/diff/diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ const (
DiffTypePrivilege
DiffTypeRevokedDefaultPrivilege
DiffTypeColumnPrivilege
DiffTypeExtension
)

// String returns the string representation of DiffType
Expand Down Expand Up @@ -103,6 +104,8 @@ func (d DiffType) String() string {
return "revoked_default_privilege"
case DiffTypeColumnPrivilege:
return "column_privilege"
case DiffTypeExtension:
return "extension"
default:
return "unknown"
}
Expand Down Expand Up @@ -177,6 +180,8 @@ func (d *DiffType) UnmarshalJSON(data []byte) error {
*d = DiffTypeRevokedDefaultPrivilege
case "column_privilege":
*d = DiffTypeColumnPrivilege
case "extension":
*d = DiffTypeExtension
default:
return fmt.Errorf("unknown diff type: %s", s)
}
Expand Down Expand Up @@ -296,6 +301,9 @@ type ddlDiff struct {
addedColumnPrivileges []*ir.ColumnPrivilege
droppedColumnPrivileges []*ir.ColumnPrivilege
modifiedColumnPrivileges []*columnPrivilegeDiff
// Cluster-level extensions
addedExtensions []*ir.Extension
droppedExtensions []*ir.Extension
}

// schemaDiff represents changes to a schema
Expand Down Expand Up @@ -460,6 +468,27 @@ func GenerateMigration(oldIR, newIR *ir.IR, targetSchema string) []Diff {
addedColumnPrivileges: []*ir.ColumnPrivilege{},
droppedColumnPrivileges: []*ir.ColumnPrivilege{},
modifiedColumnPrivileges: []*columnPrivilegeDiff{},
addedExtensions: []*ir.Extension{},
droppedExtensions: []*ir.Extension{},
}

// Compute extension diffs (cluster-level, so no schema filtering).
// Modifications (version bumps) are out of scope for this initial PR; only
// added/dropped are tracked. See #436 for the broader extension story.
{
extNames := sortedKeys(newIR.Extensions)
for _, name := range extNames {
newExt := newIR.Extensions[name]
if _, exists := oldIR.Extensions[name]; !exists {
diff.addedExtensions = append(diff.addedExtensions, newExt)
}
}
oldExtNames := sortedKeys(oldIR.Extensions)
for _, name := range oldExtNames {
if _, exists := newIR.Extensions[name]; !exists {
diff.droppedExtensions = append(diff.droppedExtensions, oldIR.Extensions[name])
}
Comment on lines +486 to +490
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Avoid global extension drops

This compares all database extensions and drops any extension missing from the desired IR, even when the command is scoped to one target schema. If the current database has an extension used by another schema or application and the desired SQL for public does not declare it, a schema-scoped plan can emit DROP EXTENSION for unrelated database state.

Context Used: CLAUDE.md (source)

Comment on lines +478 to +490
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Detect extension changes

Extension diffing only checks whether the name exists in both IRs. The IR now carries Version, Schema, and Comment, but changes to those fields produce no diff, so pgschema plan can report no changes while the installed extension metadata still differs from the desired state.

Context Used: CLAUDE.md (source)

}
}

// Compare schemas first in deterministic order
Expand Down Expand Up @@ -1499,6 +1528,10 @@ func (d *ddlDiff) generatePreDropMaterializedViewsSQL(targetSchema string, colle
func (d *ddlDiff) generateCreateSQL(targetSchema string, collector *diffCollector) {
// Note: Schema creation is out of scope for schema-level comparisons

// Extensions first: they provide operator classes, types, and functions that
// downstream schema objects (e.g., a GIST index on UUID via btree_gist) depend on.
generateCreateExtensionsSQL(d.addedExtensions, collector)

// Build function lookup early - needed for both domain and table dependency checks
newFunctionLookup := buildFunctionLookup(d.addedFunctions)

Expand Down Expand Up @@ -1721,6 +1754,10 @@ func (d *ddlDiff) generateDropSQL(targetSchema string, collector *diffCollector,
// Drop types
generateDropTypesSQL(d.droppedTypes, targetSchema, collector)

// Drop extensions last: any schema object that depended on the extension
// must already be gone before we try to drop the extension itself.
generateDropExtensionsSQL(d.droppedExtensions, collector)

// Drop schemas
// Note: Schema deletion is out of scope for schema-level comparisons
}
Expand Down
60 changes: 60 additions & 0 deletions internal/diff/extension.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package diff

import (
"fmt"

"github.com/pgplex/pgschema/ir"
)

// generateCreateExtensionsSQL generates `CREATE EXTENSION IF NOT EXISTS` statements
// for newly added extensions. Emitted before any schema-level objects because
// extensions can provide operator classes, types, and functions that those
// objects depend on (e.g., a GIST index using btree_gist's UUID operator class).
func generateCreateExtensionsSQL(extensions []*ir.Extension, collector *diffCollector) {
for _, ext := range extensions {
sql := generateExtensionSQL(ext)
context := &diffContext{
Type: DiffTypeExtension,
Operation: DiffOperationCreate,
Path: extensionPath(ext),
Source: ext,
CanRunInTransaction: true,
}
collector.collect(context, sql)
}
}

// generateDropExtensionsSQL generates `DROP EXTENSION IF EXISTS` statements
// for extensions removed from the target. Emitted after all schema-level drops
// to avoid dependency conflicts.
func generateDropExtensionsSQL(extensions []*ir.Extension, collector *diffCollector) {
for _, ext := range extensions {
context := &diffContext{
Type: DiffTypeExtension,
Operation: DiffOperationDrop,
Path: extensionPath(ext),
Source: ext,
CanRunInTransaction: true,
}
collector.collect(context, fmt.Sprintf("DROP EXTENSION IF EXISTS %s;", ir.QuoteIdentifier(ext.Name)))
}
}

// extensionPath returns the identifier used in the diff Path field. Extensions
// are cluster-level so no schema qualifier is included; doing so would leak
// the plan command's temporary schema into the recorded plan and break
// golden-output stability across runs.
func extensionPath(ext *ir.Extension) string {
return ext.Name
}

// generateExtensionSQL renders a single CREATE EXTENSION statement.
// Extensions are cluster-level; the installed schema is intentionally not
// emitted here. Honoring it would require either pinning it to the user's
// declared value (which we cannot recover from pg_extension alone — the plan
// command's temporary schema becomes the install schema when no WITH SCHEMA
// is given) or filtering out transient schemas. Preserving the user-declared
// install schema is tracked as a follow-up to #436.
func generateExtensionSQL(ext *ir.Extension) string {
return fmt.Sprintf("CREATE EXTENSION IF NOT EXISTS %s;", ir.QuoteIdentifier(ext.Name))
}
Comment on lines +58 to +60
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Quote extension identifiers

ext.Name is emitted directly into the CREATE EXTENSION statement. Valid PostgreSQL extension names can require quoting, such as uuid-ossp; this renders as CREATE EXTENSION IF NOT EXISTS uuid-ossp;, which PostgreSQL does not parse as that extension name. A dump or plan containing that common extension will fail to apply.

Suggested change
func generateExtensionSQL(ext *ir.Extension) string {
return fmt.Sprintf("CREATE EXTENSION IF NOT EXISTS %s;", ext.Name)
}
func generateExtensionSQL(ext *ir.Extension) string {
return fmt.Sprintf("CREATE EXTENSION IF NOT EXISTS %s;", ir.QuoteIdentifier(ext.Name))
}

Context Used: CLAUDE.md (source)

25 changes: 25 additions & 0 deletions ir/inspector.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ func (i *Inspector) BuildIR(ctx context.Context, targetSchema string) (*IR, erro
return nil, fmt.Errorf("failed to build metadata: %w", err)
}

if err := i.buildExtensions(ctx, schema); err != nil {
return nil, fmt.Errorf("failed to build extensions: %w", err)
}

if err := i.validateSchemaExists(ctx, targetSchema); err != nil {
return nil, err
}
Expand Down Expand Up @@ -207,6 +211,27 @@ func (i *Inspector) buildMetadata(ctx context.Context, schema *IR) error {
return nil
}

// buildExtensions records every installed extension (except plpgsql) on the IR.
// Extensions are cluster-level — they are not scoped by targetSchema.
func (i *Inspector) buildExtensions(ctx context.Context, schema *IR) error {
rows, err := i.queries.GetExtensions(ctx)
if err != nil {
return err
}
for _, row := range rows {
if !row.ExtensionName.Valid {
continue
}
schema.SetExtension(&Extension{
Name: row.ExtensionName.String,
Version: row.ExtensionVersion,
Schema: row.ExtensionSchema.String,
Comment: row.ExtensionComment.String,
})
}
return nil
}

func (i *Inspector) buildSchemas(ctx context.Context, schema *IR, targetSchema string) error {
// Use the schema-specific query to prefilter at the database level
schemaName, err := i.queries.GetSchema(ctx, sql.NullString{String: targetSchema, Valid: true})
Expand Down
39 changes: 35 additions & 4 deletions ir/ir.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,20 @@ import (

// IR represents the complete database schema intermediate representation
type IR struct {
Metadata Metadata `json:"metadata"`
Schemas map[string]*Schema `json:"schemas"` // schema_name -> Schema
mu sync.RWMutex // Protects concurrent access to Schemas
Metadata Metadata `json:"metadata"`
Extensions map[string]*Extension `json:"extensions,omitempty"` // extension_name -> Extension (cluster-level, not per-schema)
Schemas map[string]*Schema `json:"schemas"` // schema_name -> Schema
mu sync.RWMutex // Protects concurrent access to Schemas and Extensions
}

// Extension represents a PostgreSQL extension installed in the database.
// Extensions are cluster-level (installed once per database), so they live at
// the IR root rather than under a Schema.
type Extension struct {
Name string `json:"name"` // e.g., "btree_gist"
Version string `json:"version,omitempty"` // e.g., "1.7"
Schema string `json:"schema,omitempty"` // Namespace where the extension's default objects are installed
Comment string `json:"comment,omitempty"`
}

// Metadata contains information about the schema dump
Expand Down Expand Up @@ -542,10 +553,29 @@ func (cp *ColumnPrivilege) GetObjectName() string {
// NewIR creates a new empty catalog IR
func NewIR() *IR {
return &IR{
Schemas: make(map[string]*Schema),
Schemas: make(map[string]*Schema),
Extensions: make(map[string]*Extension),
}
}

// SetExtension records an extension on the IR with thread safety.
func (c *IR) SetExtension(ext *Extension) {
c.mu.Lock()
defer c.mu.Unlock()
if c.Extensions == nil {
c.Extensions = make(map[string]*Extension)
}
c.Extensions[ext.Name] = ext
}

// GetExtension retrieves an extension by name with thread safety.
func (c *IR) GetExtension(name string) (*Extension, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
ext, ok := c.Extensions[name]
return ext, ok
}

// GetSchema retrieves a schema by name with thread safety.
// Returns the schema and true if found, or nil and false if not found.
func (c *IR) GetSchema(name string) (*Schema, bool) {
Expand Down Expand Up @@ -709,4 +739,5 @@ func (p *Procedure) GetObjectName() string { return p.Name }
func (v *View) GetObjectName() string { return v.Name }
func (s *Sequence) GetObjectName() string { return s.Name }
func (t *Type) GetObjectName() string { return t.Name }
func (e *Extension) GetObjectName() string { return e.Name }

17 changes: 16 additions & 1 deletion ir/queries/queries.sql
Original file line number Diff line number Diff line change
Expand Up @@ -1446,4 +1446,19 @@ JOIN pg_namespace referenced_ns ON referenced_proc.pronamespace = referenced_ns.
WHERE d.classid = 'pg_proc'::regclass
AND d.refclassid = 'pg_proc'::regclass
AND d.deptype = 'n'
AND dependent_ns.nspname = $1;
AND dependent_ns.nspname = $1;

-- GetExtensions retrieves all installed extensions except the always-present
-- `plpgsql` built-in. Used to render `CREATE EXTENSION` statements in the dump
-- so dumps remain replayable on a fresh database.
-- name: GetExtensions :many
SELECT
e.extname::text AS extension_name,
e.extversion::text AS extension_version,
n.nspname::text AS extension_schema,
COALESCE(d.description, '') AS extension_comment
FROM pg_extension e
JOIN pg_namespace n ON e.extnamespace = n.oid
LEFT JOIN pg_description d ON d.objoid = e.oid AND d.classoid = 'pg_extension'::regclass
WHERE e.extname != 'plpgsql'
ORDER BY e.extname;
51 changes: 51 additions & 0 deletions ir/queries/queries.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions testdata/diff/create_extension/add_extension/diff.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CREATE EXTENSION IF NOT EXISTS btree_gist;
1 change: 1 addition & 0 deletions testdata/diff/create_extension/add_extension/new.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CREATE EXTENSION IF NOT EXISTS btree_gist;
1 change: 1 addition & 0 deletions testdata/diff/create_extension/add_extension/old.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
-- Empty schema (no extensions declared)
Loading