@@ -63,6 +63,32 @@ func getRelationshipPath(schema, name string) string {
6363 return filepath .Join (utils .SchemasDir , schema , "relationships" , name + ".sql" )
6464}
6565
66+ func getTableOrSequencePath (schema , name string , fsys afero.Fs ) string {
67+ // Tables may be renamed such that its sequence id doesn't contain the table name
68+ if table , found := strings .CutSuffix (name , "_id_seq" ); found {
69+ fp := getTablePath (schema , table )
70+ if exists , _ := afero .Exists (fsys , fp ); exists {
71+ return fp
72+ }
73+ }
74+ return getSequencesPath (schema )
75+ }
76+
77+ func getTableOrViewPath (schema , name string , fsys afero.Fs ) string {
78+ // View grants can use table keyword
79+ candidates := []string {
80+ getForeignTablePath (schema , name ),
81+ getViewPath (schema , name ),
82+ getMaterializedViewPath (schema , name ),
83+ }
84+ for _ , fp := range candidates {
85+ if exists , _ := afero .Exists (fsys , fp ); exists {
86+ return fp
87+ }
88+ }
89+ return getTablePath (schema , name )
90+ }
91+
6692func WriteStructuredSchemas (ctx context.Context , sql string , fsys afero.Fs ) error {
6793 stat , err := parser .ParseSQL (sql )
6894 if err != nil {
@@ -180,28 +206,25 @@ func WriteStructuredSchemas(ctx context.Context, sql string, fsys afero.Fs) erro
180206 case * ast.AlterTableStmt :
181207 if r := v .Relation ; r != nil && len (r .SchemaName ) > 0 {
182208 name = getTablePath (r .SchemaName , r .RelName )
183- // TODO: alter sequence / view statements may be parsed to wrong ast
209+ // TODO: alter sequence / view owner may be parsed to wrong ast
184210 switch v .Objtype {
185- case ast .OBJECT_TABLE :
211+ case ast .OBJECT_SEQUENCE :
212+ name = getSequencesPath (r .SchemaName )
213+ case ast .OBJECT_VIEW :
214+ name = getViewPath (r .SchemaName , r .RelName )
215+ default :
186216 if c := v .Cmds ; c != nil {
187217 for _ , e := range c .Items {
188- if n , ok := e .(* ast.AlterTableCmd ); ok {
189- switch n .Subtype {
190- case ast .AT_AddConstraint , ast .AT_AddConstraintRecurse , ast .AT_ReAddConstraint , ast .AT_ReAddDomainConstraint , ast .AT_AlterConstraint , ast .AT_ValidateConstraint , ast .AT_AddIndexConstraint , ast .AT_DropConstraint :
191- if t , ok := n .Def .(* ast.Constraint ); ok {
192- switch t .Contype {
193- case ast .CONSTR_CHECK , ast .CONSTR_FOREIGN :
194- name = getRelationshipPath (r .SchemaName , r .RelName )
195- }
218+ if t , ok := e .(* ast.AlterTableCmd ); ok {
219+ if n , ok := t .Def .(* ast.Constraint ); ok {
220+ switch n .Contype {
221+ case ast .CONSTR_FOREIGN :
222+ name = getRelationshipPath (r .SchemaName , r .RelName )
196223 }
197224 }
198225 }
199226 }
200227 }
201- case ast .OBJECT_SEQUENCE :
202- name = getSequencesPath (r .SchemaName )
203- case ast .OBJECT_VIEW :
204- name = getViewPath (r .SchemaName , r .RelName )
205228 }
206229 }
207230 case * ast.CreateForeignTableStmt :
@@ -219,6 +242,10 @@ func WriteStructuredSchemas(ctx context.Context, sql string, fsys afero.Fs) erro
219242 case * ast.ViewStmt :
220243 if r := v .View ; r != nil && len (r .SchemaName ) > 0 {
221244 name = getViewPath (r .SchemaName , r .RelName )
245+ // Adjust for forward declaration of views
246+ if exists , _ := afero .Exists (fsys , name ); exists {
247+ name = name [:len (name )- 4 ] + "-final.sql"
248+ }
222249 }
223250 case * ast.CreateSeqStmt :
224251 if r := v .Sequence ; r != nil && len (r .SchemaName ) > 0 {
@@ -267,7 +294,9 @@ func WriteStructuredSchemas(ctx context.Context, sql string, fsys afero.Fs) erro
267294 }
268295 }
269296 case * ast.CreateTriggerStmt :
270- if s := toQualifiedName (v .Funcname ); len (s ) == 2 {
297+ if r := v .Relation ; r != nil && len (r .SchemaName ) > 0 {
298+ name = getRelationshipPath (r .SchemaName , r .RelName )
299+ } else if s := toQualifiedName (v .Funcname ); len (s ) == 2 {
271300 name = getFunctionPath (s [0 ], s [1 ])
272301 }
273302 case * ast.CreatePLangStmt :
@@ -280,16 +309,16 @@ func WriteStructuredSchemas(ctx context.Context, sql string, fsys afero.Fs) erro
280309 }
281310 // Schema level entities - others
282311 case * ast.CommentStmt :
283- if s := getNodePath (v .Objtype , v .Object ); len (s ) > 0 {
312+ if s := getNodePath (v .Objtype , v .Object , fsys ); len (s ) > 0 {
284313 name = s
285314 }
286315 case * ast.AlterOwnerStmt :
287- if s := getNodePath (v .ObjectType , v .Object ); len (s ) > 0 {
316+ if s := getNodePath (v .ObjectType , v .Object , fsys ); len (s ) > 0 {
288317 name = s
289318 }
290319 case * ast.GrantStmt :
291320 if n := v .Objects ; n != nil && len (n .Items ) == 1 {
292- if s := getNodePath (v .Objtype , n .Items [0 ]); len (s ) > 0 {
321+ if s := getNodePath (v .Objtype , n .Items [0 ], fsys ); len (s ) > 0 {
293322 name = s
294323 }
295324 }
@@ -312,6 +341,13 @@ func WriteStructuredSchemas(ctx context.Context, sql string, fsys afero.Fs) erro
312341 fmt .Fprintf (utils .GetDebugLogger (), "Unqualified (%T): %s\n " , s , s .SqlString ())
313342 } else if strings .HasPrefix (name , utils .SchemasDir ) {
314343 schemaPaths = append (schemaPaths , name )
344+ if filepath .Base (name ) == "schema.sql" {
345+ schema := filepath .Base (filepath .Dir (name ))
346+ schemaPaths = append (schemaPaths ,
347+ getTypesPath (schema ),
348+ getSequencesPath (schema ),
349+ )
350+ }
315351 }
316352 if err := appendFile (name , s .SqlString ()+ ";\n " , fsys ); err != nil {
317353 return err
@@ -330,7 +366,7 @@ func WriteStructuredSchemas(ctx context.Context, sql string, fsys afero.Fs) erro
330366 return nil
331367}
332368
333- func getNodePath (obj ast.ObjectType , n ast.Node ) string {
369+ func getNodePath (obj ast.ObjectType , n ast.Node , fsys afero. Fs ) string {
334370 switch obj {
335371 // case ast.OBJECT_ACCESS_METHOD:
336372 // case ast.OBJECT_AGGREGATE:
@@ -417,7 +453,7 @@ func getNodePath(obj ast.ObjectType, n ast.Node) string {
417453 }
418454 case ast .OBJECT_SEQUENCE :
419455 if s , ok := n .(* ast.RangeVar ); ok {
420- return getSequencesPath (s .SchemaName )
456+ return getTableOrSequencePath (s .SchemaName , s . RelName , fsys )
421457 }
422458 case ast .OBJECT_SUBSCRIPTION :
423459 return subscriptionsPath
@@ -431,10 +467,10 @@ func getNodePath(obj ast.ObjectType, n ast.Node) string {
431467 case ast .OBJECT_TABLE :
432468 if nl , ok := n .(* ast.NodeList ); ok {
433469 if s := toQualifiedName (nl ); len (s ) == 2 {
434- return getTablePath (s [0 ], s [1 ])
470+ return getTableOrViewPath (s [0 ], s [1 ], fsys )
435471 }
436472 } else if r , ok := n .(* ast.RangeVar ); ok && len (r .SchemaName ) > 0 {
437- return getTablePath (r .SchemaName , r .RelName )
473+ return getTableOrViewPath (r .SchemaName , r .RelName , fsys )
438474 }
439475 case ast .OBJECT_TABLESPACE :
440476 return tablespacesPath
0 commit comments