diff --git a/snowflake/analysis/query_span_test.go b/snowflake/analysis/query_span_test.go index 01670844..0523c219 100644 --- a/snowflake/analysis/query_span_test.go +++ b/snowflake/analysis/query_span_test.go @@ -559,3 +559,90 @@ func TestQuerySpan_ParseError(t *testing.T) { t.Error("expected error for invalid SQL, got nil") } } + +// --------------------------------------------------------------------------- +// Test 19: CASE expression columns are attributed (walker coverage regression) +// --------------------------------------------------------------------------- + +// The generated AST walker used to skip CaseExpr.Whens (a []*WhenClause of a +// non-Node helper struct), so the WHEN condition and THEN result columns were +// invisible to ast.Inspect and never attributed to the result column. +func TestQuerySpan_CaseExpressionColumns(t *testing.T) { + span := mustExtract(t, "SELECT CASE WHEN a > 0 THEN b ELSE c END AS r FROM t") + + if len(span.Results) != 1 { + t.Fatalf("Results: got %d, want 1", len(span.Results)) + } + rc := span.Results[0] + if rc.Name != "R" { + t.Errorf("result name: got %q, want %q", rc.Name, "R") + } + if !rc.IsDerived { + t.Error("IsDerived should be true for a CASE expression") + } + keys := resultSourceKeys(span, 0) + want := []string{"..T.A", "..T.B", "..T.C"} + if len(keys) != len(want) { + t.Fatalf("result sources: got %v, want %v", keys, want) + } + for i := range want { + if keys[i] != want[i] { + t.Errorf("source[%d]: got %q, want %q", i, keys[i], want[i]) + } + } +} + +// --------------------------------------------------------------------------- +// Test 20: window function columns are attributed (walker coverage regression) +// --------------------------------------------------------------------------- + +// FuncCallExpr.Over (*WindowSpec, a non-Node helper struct) used to be skipped +// by the walker, hiding PARTITION BY / ORDER BY columns from the span. +func TestQuerySpan_WindowFunctionColumns(t *testing.T) { + span := mustExtract(t, "SELECT ROW_NUMBER() OVER (PARTITION BY dept ORDER BY salary) AS rn FROM emp") + + if len(span.Results) != 1 { + t.Fatalf("Results: got %d, want 1", len(span.Results)) + } + rc := span.Results[0] + if rc.Name != "RN" { + t.Errorf("result name: got %q, want %q", rc.Name, "RN") + } + if !rc.IsDerived { + t.Error("IsDerived should be true for a window function") + } + keys := resultSourceKeys(span, 0) + want := []string{"..EMP.DEPT", "..EMP.SALARY"} + if len(keys) != len(want) { + t.Fatalf("result sources: got %v, want %v", keys, want) + } + for i := range want { + if keys[i] != want[i] { + t.Errorf("source[%d]: got %q, want %q", i, keys[i], want[i]) + } + } +} + +// --------------------------------------------------------------------------- +// Test 21: WITHIN GROUP order columns are attributed (walker coverage regression) +// --------------------------------------------------------------------------- + +// FuncCallExpr.OrderBy ([]*OrderItem, non-Node helper structs) used to be +// skipped by the walker, hiding WITHIN GROUP (ORDER BY ...) columns. +func TestQuerySpan_WithinGroupColumns(t *testing.T) { + span := mustExtract(t, "SELECT LISTAGG(name, ',') WITHIN GROUP (ORDER BY pos) AS l FROM t") + + if len(span.Results) != 1 { + t.Fatalf("Results: got %d, want 1", len(span.Results)) + } + keys := resultSourceKeys(span, 0) + want := []string{"..T.NAME", "..T.POS"} + if len(keys) != len(want) { + t.Fatalf("result sources: got %v, want %v", keys, want) + } + for i := range want { + if keys[i] != want[i] { + t.Errorf("source[%d]: got %q, want %q", i, keys[i], want[i]) + } + } +} diff --git a/snowflake/ast/cmd/genwalker/main.go b/snowflake/ast/cmd/genwalker/main.go index 15fa350d..9c02dbd8 100644 --- a/snowflake/ast/cmd/genwalker/main.go +++ b/snowflake/ast/cmd/genwalker/main.go @@ -1,8 +1,27 @@ // Command genwalker generates walk_generated.go from parsenodes.go and node.go. // -// It scans all struct types in those files, identifies fields whose types are -// Node (interface), []Node, or pointers to other AST structs, and generates -// the walkChildren function that enumerates child nodes for Walk. +// It scans all struct types in those files, identifies fields that carry child +// nodes, and generates the walkChildren function that enumerates child nodes +// for Walk. +// +// A field carries child nodes when its type is one of: +// +// - Node, []Node, [][]Node — the Node interface and slices of it +// - *T, []*T, T, []T — where T is a struct implementing Node +// (has a Tag() method) +// - *H, []*H, H, []H — where H is a NON-Node helper struct +// (e.g. WhenClause, OrderItem, WindowSpec) that +// transitively holds node-carrying fields +// +// For each node-bearing helper struct a dedicated walk function is +// generated that walks the helper's own node-bearing fields (recursing into +// nested helpers), so expressions stored inside helpers — CASE WHEN +// conditions, OVER (PARTITION BY ... ORDER BY ...) expressions, CTE bodies, +// SELECT-list targets, and the like — are reachable from Walk / Inspect. +// +// If a field's type mentions Node, a Node struct, or a node-bearing helper in +// a shape the generator does not understand (e.g. a map or [][]*T), generation +// fails loudly rather than silently skipping the field. // // Usage: // @@ -25,6 +44,34 @@ import ( "strings" ) +// shape classifies how a field's type carries child nodes. +type shape int + +const ( + shapeNone shape = iota // does not carry nodes + shapeNodeIface // Node + shapeNodeIfaceSlice // []Node + shapeNodeIfaceRows // [][]Node + shapeNodePtr // *T, T a Node struct + shapeNodePtrSlice // []*T, T a Node struct + shapeNodeValue // T, T a Node struct (by value) + shapeNodeValueSlice // []T, T a Node struct (by value) + shapeHelperPtr // *H, H a node-bearing helper struct + shapeHelperPtrSlice // []*H, H a node-bearing helper struct + shapeHelperValue // H, H a node-bearing helper struct (by value) + shapeHelperValueSlice // []H, H a node-bearing helper struct (by value) +) + +type field struct { + Name string + Type string // "Node", "[]Node", "*SelectStmt", "[]*WhenClause", etc. +} + +type structInfo struct { + Name string + Fields []field +} + func main() { fset := token.NewFileSet() @@ -42,45 +89,22 @@ func main() { files = append(files, f) } - // Collect all struct type names so we can detect pointer-to-known-struct - // fields below. - structNames := map[string]bool{} - type field struct { - Name string - Type string // "Node", "[]Node", "*SelectStmt", etc. - } - type structInfo struct { - Name string - Fields []field - } - - var structs []structInfo - - // nodeStructs tracks which structs implement the Node interface (have a Tag() method). - // Only these structs get their own case in the walkChildren switch, and only - // pointer-to-Node-struct fields generate Walk calls. + // nodeStructs tracks which structs implement the Node interface (have a + // Tag() method). Only these structs get their own case in the walkChildren + // switch. nodeStructs := map[string]bool{} + // structFields records every struct's full field list (all fields, not + // just node-bearing ones) so helper-struct reachability can be computed. + structFields := map[string][]field{} for _, f := range files { - for _, decl := range f.Decls { - gd, ok := decl.(*ast.GenDecl) - if !ok || gd.Tok != token.TYPE { - continue - } - for _, spec := range gd.Specs { - ts := spec.(*ast.TypeSpec) - if _, ok := ts.Type.(*ast.StructType); ok { - structNames[ts.Name.Name] = true - } - } - } // Scan function declarations for Tag() methods to identify Node types. for _, decl := range f.Decls { fd, ok := decl.(*ast.FuncDecl) if !ok || fd.Recv == nil || fd.Name.Name != "Tag" { continue } - // Check it returns NodeTag and has no parameters. + // Check it returns a single result and has no parameters. if fd.Type.Params != nil && len(fd.Type.Params.List) > 0 { continue } @@ -103,10 +127,7 @@ func main() { } } } - } - - // Collect fields for each struct. - for _, f := range files { + // Collect fields for each struct. for _, decl := range f.Decls { gd, ok := decl.(*ast.GenDecl) if !ok || gd.Tok != token.TYPE { @@ -118,28 +139,75 @@ func main() { if !ok { continue } - var fields []field for _, fld := range st.Fields.List { if len(fld.Names) == 0 { continue // embedded } typStr := typeString(fld.Type) - if isChildType(typStr, nodeStructs) { - for _, name := range fld.Names { - fields = append(fields, field{Name: name.Name, Type: typStr}) - } + for _, name := range fld.Names { + fields = append(fields, field{Name: name.Name, Type: typStr}) } } - structs = append(structs, structInfo{Name: ts.Name.Name, Fields: fields}) + structFields[ts.Name.Name] = fields } } } - // Sort by name for deterministic output. - sort.Slice(structs, func(i, j int) bool { - return structs[i].Name < structs[j].Name - }) + // Compute the set of node-bearing helper structs: non-Node structs that + // transitively hold node-carrying fields. Fixpoint iteration handles + // helper-inside-helper nesting (e.g. WindowSpec → WindowFrame → + // WindowBound). + bearingHelpers := map[string]bool{} + for changed := true; changed; { + changed = false + for name, fields := range structFields { + if nodeStructs[name] || bearingHelpers[name] || name == "Loc" { + continue + } + for _, f := range fields { + if classify(f.Type, nodeStructs, bearingHelpers) != shapeNone { + bearingHelpers[name] = true + changed = true + break + } + } + } + } + + // Guard: any field whose type mentions Node, a Node struct, or a + // node-bearing helper must classify to a known shape — otherwise the + // walker would silently skip child nodes. Fail generation instead. + var unsupported []string + for name, fields := range structFields { + if !nodeStructs[name] && !bearingHelpers[name] { + continue + } + for _, f := range fields { + if classify(f.Type, nodeStructs, bearingHelpers) != shapeNone { + continue + } + b := baseType(f.Type) + if b == "Node" || nodeStructs[b] || bearingHelpers[b] { + unsupported = append(unsupported, fmt.Sprintf("%s.%s %s", name, f.Name, f.Type)) + } + } + } + if len(unsupported) > 0 { + sort.Strings(unsupported) + fmt.Fprintf(os.Stderr, "genwalker: node-bearing fields with unsupported type shapes:\n") + for _, u := range unsupported { + fmt.Fprintf(os.Stderr, " %s\n", u) + } + os.Exit(1) + } + + // Sort node structs by name for deterministic output. + var names []string + for n := range structFields { + names = append(names, n) + } + sort.Strings(names) // Generate code. var buf bytes.Buffer @@ -150,35 +218,54 @@ func main() { buf.WriteString("func walkChildren(v Visitor, node Node) {\n") buf.WriteString("\tswitch n := node.(type) {\n") - for _, s := range structs { - if len(s.Fields) == 0 { - continue - } + cases := 0 + caseFields := 0 + for _, s := range names { // Only emit cases for structs that implement Node (have a Tag() method). - if !nodeStructs[s.Name] { + if !nodeStructs[s] { continue } - fmt.Fprintf(&buf, "\tcase *%s:\n", s.Name) - for _, f := range s.Fields { - switch f.Type { - case "Node": - fmt.Fprintf(&buf, "\t\tWalk(v, n.%s)\n", f.Name) - case "[]Node": - fmt.Fprintf(&buf, "\t\twalkNodes(v, n.%s)\n", f.Name) - case "[][]Node": - fmt.Fprintf(&buf, "\t\twalkNodeRows(v, n.%s)\n", f.Name) - default: - // Pointer to a known struct type (e.g. *SelectStmt). - fmt.Fprintf(&buf, "\t\tif n.%s != nil {\n", f.Name) - fmt.Fprintf(&buf, "\t\t\tWalk(v, n.%s)\n", f.Name) - fmt.Fprintf(&buf, "\t\t}\n") + var walkable []field + for _, f := range structFields[s] { + if classify(f.Type, nodeStructs, bearingHelpers) != shapeNone { + walkable = append(walkable, f) } } + if len(walkable) == 0 { + continue + } + cases++ + caseFields += len(walkable) + fmt.Fprintf(&buf, "\tcase *%s:\n", s) + for _, f := range walkable { + emitFieldWalk(&buf, "\t\t", "n."+f.Name, f.Type, nodeStructs, bearingHelpers) + } } buf.WriteString("\t}\n") buf.WriteString("}\n") + // Emit one walk function per node-bearing helper struct, sorted by name. + helperCount := 0 + for _, s := range names { + if !bearingHelpers[s] { + continue + } + helperCount++ + fmt.Fprintf(&buf, "\n// walk%s walks the node-bearing fields of the non-Node helper struct\n", s) + fmt.Fprintf(&buf, "// %s. The helper itself is not visited (it is not a Node); its child\n", s) + buf.WriteString("// nodes are walked in field order.\n") + fmt.Fprintf(&buf, "func walk%s(v Visitor, n *%s) {\n", s, s) + buf.WriteString("\tif n == nil {\n\t\treturn\n\t}\n") + for _, f := range structFields[s] { + if classify(f.Type, nodeStructs, bearingHelpers) == shapeNone { + continue + } + emitFieldWalk(&buf, "\t", "n."+f.Name, f.Type, nodeStructs, bearingHelpers) + } + buf.WriteString("}\n") + } + // Format with gofmt rules. formatted, err := format.Source(buf.Bytes()) if err != nil { @@ -193,16 +280,102 @@ func main() { os.Exit(1) } - // Stats. - cases := 0 - fields := 0 - for _, s := range structs { - if len(s.Fields) > 0 && nodeStructs[s.Name] { - cases++ - fields += len(s.Fields) - } + fmt.Printf("Generated walk_generated.go: %d cases, %d child fields, %d helper walkers\n", + cases, caseFields, helperCount) +} + +// emitFieldWalk writes the walk statement(s) for one node-bearing field. +// expr is the field access expression (e.g. "n.Whens"), indent is the leading +// whitespace for the first line. +func emitFieldWalk(buf *bytes.Buffer, indent, expr, typStr string, nodeStructs, bearingHelpers map[string]bool) { + switch classify(typStr, nodeStructs, bearingHelpers) { + case shapeNodeIface: + fmt.Fprintf(buf, "%sWalk(v, %s)\n", indent, expr) + case shapeNodeIfaceSlice: + fmt.Fprintf(buf, "%swalkNodes(v, %s)\n", indent, expr) + case shapeNodeIfaceRows: + fmt.Fprintf(buf, "%swalkNodeRows(v, %s)\n", indent, expr) + case shapeNodePtr: + fmt.Fprintf(buf, "%sif %s != nil {\n", indent, expr) + fmt.Fprintf(buf, "%s\tWalk(v, %s)\n", indent, expr) + fmt.Fprintf(buf, "%s}\n", indent) + case shapeNodePtrSlice: + fmt.Fprintf(buf, "%sfor _, item := range %s {\n", indent, expr) + fmt.Fprintf(buf, "%s\tif item != nil {\n", indent) + fmt.Fprintf(buf, "%s\t\tWalk(v, item)\n", indent) + fmt.Fprintf(buf, "%s\t}\n", indent) + fmt.Fprintf(buf, "%s}\n", indent) + case shapeNodeValue: + fmt.Fprintf(buf, "%sWalk(v, &%s)\n", indent, expr) + case shapeNodeValueSlice: + fmt.Fprintf(buf, "%sfor i := range %s {\n", indent, expr) + fmt.Fprintf(buf, "%s\tWalk(v, &%s[i])\n", indent, expr) + fmt.Fprintf(buf, "%s}\n", indent) + case shapeHelperPtr: + fmt.Fprintf(buf, "%swalk%s(v, %s)\n", indent, baseType(typStr), expr) + case shapeHelperPtrSlice: + fmt.Fprintf(buf, "%sfor _, item := range %s {\n", indent, expr) + fmt.Fprintf(buf, "%s\twalk%s(v, item)\n", indent, baseType(typStr)) + fmt.Fprintf(buf, "%s}\n", indent) + case shapeHelperValue: + fmt.Fprintf(buf, "%swalk%s(v, &%s)\n", indent, baseType(typStr), expr) + case shapeHelperValueSlice: + fmt.Fprintf(buf, "%sfor i := range %s {\n", indent, expr) + fmt.Fprintf(buf, "%s\twalk%s(v, &%s[i])\n", indent, baseType(typStr), expr) + fmt.Fprintf(buf, "%s}\n", indent) + } +} + +// classify maps a field type string to the shape of child-node traversal it +// needs. Returns shapeNone for types that carry no nodes (enums, Ident, Loc, +// scalars, inert helper structs). +func classify(typStr string, nodeStructs, bearingHelpers map[string]bool) shape { + switch typStr { + case "Node": + return shapeNodeIface + case "[]Node": + return shapeNodeIfaceSlice + case "[][]Node": + return shapeNodeIfaceRows + } + name := typStr + slice := false + if strings.HasPrefix(name, "[]") { + slice = true + name = name[2:] + } + ptr := false + if strings.HasPrefix(name, "*") { + ptr = true + name = name[1:] + } + // Reject any remaining wrapper (e.g. [][]*T, **T, maps) — the guard in + // main reports node-bearing fields that end up here. + if strings.ContainsAny(name, "[]*. ") { + return shapeNone } - fmt.Printf("Generated walk_generated.go: %d cases, %d child fields\n", cases, fields) + if name == "Loc" { + return shapeNone + } + switch { + case nodeStructs[name] && ptr && slice: + return shapeNodePtrSlice + case nodeStructs[name] && ptr: + return shapeNodePtr + case nodeStructs[name] && slice: + return shapeNodeValueSlice + case nodeStructs[name]: + return shapeNodeValue + case bearingHelpers[name] && ptr && slice: + return shapeHelperPtrSlice + case bearingHelpers[name] && ptr: + return shapeHelperPtr + case bearingHelpers[name] && slice: + return shapeHelperValueSlice + case bearingHelpers[name]: + return shapeHelperValue + } + return shapeNone } // typeString returns the string representation of a Go type expression. @@ -222,33 +395,16 @@ func typeString(expr ast.Expr) string { } } -// isChildType reports whether typStr represents a field that walkChildren -// should descend into. -// -// Recognized shapes: -// - "Node" — the Node interface -// - "[]Node" — slice of nodes -// - "[][]Node" — slice of rows of nodes (e.g. VALUES rows) -// - "*" — pointer to a struct that implements Node (has Tag() method) -// -// Excluded: pointer to "Loc" (Loc is not a node), pointer to non-Node structs -// (helper types like WhenClause, WindowSpec, etc.), and any other shape. -func isChildType(typStr string, nodeStructs map[string]bool) bool { - if typStr == "Node" { - return true - } - if typStr == "[]Node" { - return true - } - if typStr == "[][]Node" { - return true - } - if strings.HasPrefix(typStr, "*") { - name := typStr[1:] - if name == "Loc" { - return false +// baseType strips slice and pointer wrappers: "[]*WhenClause" → "WhenClause". +func baseType(t string) string { + for { + switch { + case strings.HasPrefix(t, "[]"): + t = t[2:] + case strings.HasPrefix(t, "*"): + t = t[1:] + default: + return t } - return nodeStructs[name] } - return false } diff --git a/snowflake/ast/walk_coverage_test.go b/snowflake/ast/walk_coverage_test.go new file mode 100644 index 00000000..4af30799 --- /dev/null +++ b/snowflake/ast/walk_coverage_test.go @@ -0,0 +1,411 @@ +package ast_test + +// Walker coverage regression tests for node-bearing fields whose types are +// not themselves Node (helper structs like WhenClause / OrderItem / +// WindowSpec / CTE / SelectTarget), slices of pointers to Node structs +// ([]*ObjectName, []*ColumnDef, ...), and by-value Node structs +// (FuncCallExpr.Name). The generated walker used to skip ALL of these, so +// ast.Inspect never reached e.g. CASE WHEN conditions or OVER (PARTITION BY +// ... ORDER BY ...) expressions, silently breaking column collection in every +// walker-based consumer (analysis query-span included). +// +// These tests parse real SQL and assert that specific nodes are visited. +// They live in an external test package so they can use the parser. + +import ( + "strings" + "testing" + + "github.com/bytebase/omni/snowflake/ast" + "github.com/bytebase/omni/snowflake/parser" +) + +// walkSummary records everything one ast.Inspect traversal visited. +type walkSummary struct { + tags map[ast.NodeTag]int + colRefs map[string]bool // normalized dotted ColumnRef ("A", "S.MF") + objNames map[string]bool // normalized dotted ObjectName ("MYSCHEMA.MYFUNC") + literals map[string]bool // raw Literal source text ("10", "'JAN'") +} + +// summarize parses sql and walks the resulting file, recording all visits. +func summarize(t *testing.T, sql string) *walkSummary { + t.Helper() + file, err := parser.Parse(sql) + if err != nil { + t.Fatalf("parse %q: %v", sql, err) + } + s := &walkSummary{ + tags: map[ast.NodeTag]int{}, + colRefs: map[string]bool{}, + objNames: map[string]bool{}, + literals: map[string]bool{}, + } + ast.Inspect(file, func(n ast.Node) bool { + if n == nil { + return false + } + s.tags[n.Tag()]++ + switch x := n.(type) { + case *ast.ColumnRef: + parts := make([]string, len(x.Parts)) + for i, p := range x.Parts { + parts[i] = p.Normalize() + } + s.colRefs[strings.Join(parts, ".")] = true + case *ast.ObjectName: + s.objNames[x.Normalize()] = true + case *ast.Literal: + s.literals[x.Value] = true + } + return true + }) + return s +} + +// walkCoverageCase asserts minimum visit evidence for one SQL statement. +type walkCoverageCase struct { + name string + sql string + cols []string // ColumnRefs that must be visited (normalized) + objs []string // ObjectNames that must be visited (normalized) + lits []string // Literals that must be visited (raw text) + tags map[ast.NodeTag]int // minimum visit count per tag +} + +func runWalkCoverage(t *testing.T, cases []walkCoverageCase) { + t.Helper() + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + s := summarize(t, c.sql) + for _, col := range c.cols { + if !s.colRefs[col] { + t.Errorf("ColumnRef %q not visited by Walk; visited: %v", col, keys(s.colRefs)) + } + } + for _, obj := range c.objs { + if !s.objNames[obj] { + t.Errorf("ObjectName %q not visited by Walk; visited: %v", obj, keys(s.objNames)) + } + } + for _, lit := range c.lits { + if !s.literals[lit] { + t.Errorf("Literal %q not visited by Walk; visited: %v", lit, keys(s.literals)) + } + } + for tag, min := range c.tags { + if s.tags[tag] < min { + t.Errorf("tag %v visited %d times, want >= %d", tag, s.tags[tag], min) + } + } + }) + } +} + +func keys(m map[string]bool) []string { + out := make([]string, 0, len(m)) + for k := range m { + out = append(out, k) + } + return out +} + +// --------------------------------------------------------------------------- +// Expression-level gaps (CASE, window functions, WITHIN GROUP, JSON literals) +// --------------------------------------------------------------------------- + +func TestWalkCoverage_Expressions(t *testing.T) { + runWalkCoverage(t, []walkCoverageCase{ + { + name: "CaseExpr.Whens searched", + sql: "SELECT CASE WHEN wa > 0 THEN wb ELSE wc END FROM t", + cols: []string{"WA", "WB", "WC"}, + }, + { + name: "CaseExpr.Whens simple", + sql: "SELECT CASE sx WHEN sy THEN sz END FROM t", + cols: []string{"SX", "SY", "SZ"}, + }, + { + name: "FuncCallExpr.Over partition and order", + sql: "SELECT ROW_NUMBER() OVER (PARTITION BY px ORDER BY oy) FROM t", + cols: []string{"PX", "OY"}, + }, + { + name: "FuncCallExpr.Over frame bound offsets", + sql: "SELECT SUM(sx) OVER (ORDER BY oy ROWS BETWEEN 5 PRECEDING AND 3 FOLLOWING) FROM t", + cols: []string{"SX", "OY"}, + lits: []string{"5", "3"}, + }, + { + name: "FuncCallExpr.OrderBy within group", + sql: "SELECT LISTAGG(lx, ',') WITHIN GROUP (ORDER BY wy) FROM t", + cols: []string{"LX", "WY"}, + }, + { + name: "FuncCallExpr.Name qualified function name", + sql: "SELECT myschema.myfunc(fa) FROM t", + cols: []string{"FA"}, + objs: []string{"MYSCHEMA.MYFUNC"}, + }, + { + name: "JsonLiteralExpr.Pairs values", + sql: "SELECT {'k': jv} FROM t", + cols: []string{"JV"}, + }, + }) +} + +// --------------------------------------------------------------------------- +// SELECT clause gaps (targets, WITH, GROUP BY, ORDER BY, FETCH) +// --------------------------------------------------------------------------- + +func TestWalkCoverage_SelectClauses(t *testing.T) { + runWalkCoverage(t, []walkCoverageCase{ + { + name: "SelectStmt.Targets expressions", + sql: "SELECT ta + tb AS s FROM t", + cols: []string{"TA", "TB"}, + }, + { + name: "SelectStmt.With CTE body", + sql: "WITH c AS (SELECT ca FROM t) SELECT * FROM c", + cols: []string{"CA"}, + tags: map[ast.NodeTag]int{ast.T_SelectStmt: 2}, + }, + { + name: "SelectStmt.GroupBy items", + sql: "SELECT COUNT(*) FROM t GROUP BY ga, gb", + cols: []string{"GA", "GB"}, + }, + { + name: "SelectStmt.OrderBy items", + sql: "SELECT a FROM t ORDER BY ob DESC NULLS LAST", + cols: []string{"OB"}, + }, + { + name: "SelectStmt.Fetch count", + sql: "SELECT a FROM t FETCH FIRST 10 ROWS ONLY", + lits: []string{"10"}, + }, + { + name: "scalar subquery in select list reaches inner select", + sql: "SELECT (SELECT inner_col FROM u) FROM t", + cols: []string{"INNER_COL"}, + tags: map[ast.NodeTag]int{ast.T_SelectStmt: 2}, + }, + }) +} + +// --------------------------------------------------------------------------- +// FROM-clause helper gaps (PIVOT IN values/order, UNPIVOT columns, +// MATCH_RECOGNIZE measures/define/order) +// --------------------------------------------------------------------------- + +func TestWalkCoverage_FromClauses(t *testing.T) { + runWalkCoverage(t, []walkCoverageCase{ + { + name: "PivotInClause.Values", + sql: "SELECT * FROM monthly_sales PIVOT (SUM(amount) FOR month IN ('JAN', 'FEB'))", + cols: []string{"AMOUNT", "MONTH"}, + tags: map[ast.NodeTag]int{ast.T_PivotValue: 2}, + }, + { + name: "PivotInClause.OrderBy", + sql: "SELECT * FROM quarterly_sales PIVOT(SUM(amount) FOR quarter IN (ANY ORDER BY quarter)) ORDER BY empid", + cols: []string{"QUARTER", "EMPID"}, + }, + { + name: "UnpivotClause.Columns", + sql: "SELECT * FROM monthly_sales UNPIVOT (sales FOR month IN (jan, feb))", + tags: map[ast.NodeTag]int{ast.T_UnpivotColumn: 2}, + }, + { + name: "MatchRecognizeClause OrderBy/Measures/Define", + sql: `SELECT * FROM stock_price_history MATCH_RECOGNIZE( + PARTITION BY company + ORDER BY price_date + MEASURES FIRST(price_date) AS start_date + ONE ROW PER MATCH + PATTERN(dn up) + DEFINE dn AS price < LAG(price), up AS price > LAG(price) + )`, + cols: []string{"COMPANY", "PRICE_DATE", "PRICE"}, + }, + }) +} + +// --------------------------------------------------------------------------- +// DML gaps (UPDATE SET, MERGE WHEN, INSERT ALL branches, SET variables) +// --------------------------------------------------------------------------- + +func TestWalkCoverage_DML(t *testing.T) { + runWalkCoverage(t, []walkCoverageCase{ + { + name: "UpdateStmt.Sets value expressions", + sql: "UPDATE t SET c1 = ux + 1 WHERE id = 5", + cols: []string{"UX", "ID"}, + objs: []string{"C1"}, + }, + { + name: "MergeStmt.Whens conditions, sets, and insert values", + sql: `MERGE INTO t USING s ON t.id = s.id + WHEN MATCHED AND s.mf > 0 THEN UPDATE SET c = s.mv + WHEN NOT MATCHED THEN INSERT (a) VALUES (s.mw)`, + cols: []string{"S.MF", "S.MV", "S.MW"}, + }, + { + name: "InsertMultiStmt.Branches targets and values", + sql: `INSERT ALL + INTO t1 VALUES (1, 'a') + INTO t2 VALUES (2, 'b') + SELECT * FROM src`, + objs: []string{"T1", "T2"}, + lits: []string{"1", "2", "a", "b"}, + }, + { + name: "SetStmt.Vars values", + sql: "SET V1 = 10", + lits: []string{"10"}, + }, + }) +} + +// --------------------------------------------------------------------------- +// DDL gaps (column defs, constraints, indexes, clones, tags, policies, +// routine signatures, copy options, stream time travel, ...) +// --------------------------------------------------------------------------- + +func TestWalkCoverage_DDL(t *testing.T) { + runWalkCoverage(t, []walkCoverageCase{ + { + name: "CreateTableStmt.Columns and Constraints with FK references", + sql: "CREATE TABLE t (c1 INT DEFAULT 7, c2 INT, CONSTRAINT fk FOREIGN KEY (c2) REFERENCES p (py))", + lits: []string{"7"}, + objs: []string{"P"}, + tags: map[ast.NodeTag]int{ast.T_ColumnDef: 2, ast.T_TableConstraint: 1, ast.T_TypeName: 2}, + }, + { + name: "ColumnDef.InlineConstraint references", + sql: "CREATE TABLE t (customer_id INT FOREIGN KEY REFERENCES customers (id))", + objs: []string{"CUSTOMERS"}, + }, + { + name: "CreateTableStmt.Indexes (hybrid table)", + sql: "CREATE HYBRID TABLE t (id INT, full_name VARCHAR(255), INDEX index_full_name (full_name))", + tags: map[ast.NodeTag]int{ast.T_TableIndex: 1}, + }, + { + name: "CreateTableStmt.Clone source", + sql: "CREATE TABLE t1 CLONE t2", + objs: []string{"T2"}, + }, + { + name: "CreateViewStmt.RowPolicy", + sql: "CREATE VIEW v WITH ROW ACCESS POLICY my_policy ON (col1, col2) AS SELECT 1", + objs: []string{"MY_POLICY"}, + }, + { + name: "AlterTableStmt.Actions add column with default", + sql: "ALTER TABLE t ADD COLUMN nc INT DEFAULT 3", + lits: []string{"3"}, + tags: map[ast.NodeTag]int{ast.T_ColumnDef: 1}, + }, + { + name: "AlterTableAction.UnsetTags", + sql: "ALTER TABLE t UNSET TAG (my_tag, other_tag)", + objs: []string{"MY_TAG", "OTHER_TAG"}, + }, + { + name: "AlterViewStmt set tag assignments", + sql: "ALTER VIEW v SET TAG (env = 'prod')", + objs: []string{"ENV"}, + }, + { + name: "CopyIntoTableStmt.Options nested file format group", + sql: "COPY INTO t FROM @s FILE_FORMAT = (TYPE = CSV SKIP_HEADER = 1 FIELD_DELIMITER = ',')", + tags: map[ast.NodeTag]int{ast.T_CopyOption: 3}, + }, + { + name: "CreateRoutineStmt argument and return types", + sql: "CREATE FUNCTION multiply(a NUMBER, b NUMBER) RETURNS NUMBER AS 'a * b'", + tags: map[ast.NodeTag]int{ast.T_TypeName: 3}, + }, + { + name: "AlterRoutineStmt.ArgTypes signature", + sql: "ALTER FUNCTION f(NUMBER) RENAME TO g", + tags: map[ast.NodeTag]int{ast.T_TypeName: 1}, + }, + { + name: "CommentStmt.Signature types", + sql: "COMMENT ON FUNCTION f(INT, STRING) IS 'fn'", + tags: map[ast.NodeTag]int{ast.T_TypeName: 2}, + }, + { + name: "CreateStreamStmt.TimeTravel value", + sql: "CREATE STREAM s ON TABLE t AT (TIMESTAMP => TO_TIMESTAMP(40*365*86400))", + lits: []string{"40", "365", "86400"}, + }, + { + name: "CreatePolicyStmt.Args", + sql: "CREATE MASKING POLICY mp AS (val string) RETURNS string -> val", + cols: []string{"VAL"}, + tags: map[ast.NodeTag]int{ast.T_PolicyArg: 1}, + }, + { + name: "CreateSemanticViewStmt.Sections", + sql: "CREATE SEMANTIC VIEW sv TABLES (orders) METRICS (orders.total AS SUM(amount))", + tags: map[ast.NodeTag]int{ast.T_SemanticViewSection: 2}, + }, + { + name: "CreateIntegrationStmt.Triggers (resource monitor)", + sql: "CREATE RESOURCE MONITOR rm WITH CREDIT_QUOTA = 100 TRIGGERS ON 75 PERCENT DO NOTIFY", + tags: map[ast.NodeTag]int{ast.T_ResourceMonitorTrigger: 1}, + }, + { + name: "CreateReplicationGroupStmt.Options", + sql: "CREATE REPLICATION GROUP rg OBJECT_TYPES = DATABASES ALLOWED_ACCOUNTS = org.acct", + tags: map[ast.NodeTag]int{ast.T_GroupOption: 1}, + }, + { + name: "AlterExternalTableStmt.Files literals", + sql: "ALTER EXTERNAL TABLE et ADD FILES ('p/f1.parquet', 'p/f2.parquet')", + lits: []string{"p/f1.parquet", "p/f2.parquet"}, + }, + }) +} + +// TestWalkCoverage_PostOrderBalance verifies the widened traversal still +// delivers balanced pre/post events (every recursion ends with Visit(nil)). +func TestWalkCoverage_PostOrderBalance(t *testing.T) { + sql := `WITH c AS (SELECT ca FROM t) + SELECT CASE WHEN a > 0 THEN b END, + ROW_NUMBER() OVER (PARTITION BY x ORDER BY y ROWS BETWEEN 1 PRECEDING AND CURRENT ROW) + FROM c GROUP BY g ORDER BY o FETCH FIRST 1 ROWS ONLY` + file, err := parser.Parse(sql) + if err != nil { + t.Fatalf("parse: %v", err) + } + pre, post := 0, 0 + v := &balanceVisitor{pre: &pre, post: &post} + ast.Walk(v, file) + if pre != post { + t.Errorf("unbalanced traversal: %d pre-order visits, %d post-order visits", pre, post) + } + if pre == 0 { + t.Error("no nodes visited") + } +} + +type balanceVisitor struct { + pre, post *int +} + +func (v *balanceVisitor) Visit(n ast.Node) ast.Visitor { + if n == nil { + *v.post++ + return nil + } + *v.pre++ + return v +} diff --git a/snowflake/ast/walk_generated.go b/snowflake/ast/walk_generated.go index e6833f19..3b44919f 100644 --- a/snowflake/ast/walk_generated.go +++ b/snowflake/ast/walk_generated.go @@ -13,6 +13,19 @@ func walkChildren(v Visitor, node Node) { if n.Name != nil { Walk(v, n.Name) } + for _, item := range n.Options { + if item != nil { + Walk(v, item) + } + } + for _, item := range n.Tags { + walkTagAssignment(v, item) + } + for _, item := range n.UnsetTags { + if item != nil { + Walk(v, item) + } + } if n.PolicyName != nil { Walk(v, n.PolicyName) } @@ -23,6 +36,11 @@ func walkChildren(v Visitor, node Node) { if n.Name != nil { Walk(v, n.Name) } + for _, item := range n.Options { + if item != nil { + Walk(v, item) + } + } Walk(v, n.Condition) Walk(v, n.ActionBody) case *AlterDatabaseStmt: @@ -32,6 +50,14 @@ func walkChildren(v Visitor, node Node) { if n.NewName != nil { Walk(v, n.NewName) } + for _, item := range n.Tags { + walkTagAssignment(v, item) + } + for _, item := range n.UnsetTags { + if item != nil { + Walk(v, item) + } + } case *AlterDynamicTableStmt: if n.Name != nil { Walk(v, n.Name) @@ -39,11 +65,42 @@ func walkChildren(v Visitor, node Node) { if n.NewName != nil { Walk(v, n.NewName) } + for _, item := range n.Options { + if item != nil { + Walk(v, item) + } + } + for _, item := range n.Tags { + walkTagAssignment(v, item) + } + for _, item := range n.UnsetTags { + if item != nil { + Walk(v, item) + } + } walkNodes(v, n.ClusterBy) case *AlterExternalTableStmt: if n.Name != nil { Walk(v, n.Name) } + for _, item := range n.Files { + if item != nil { + Walk(v, item) + } + } + for _, item := range n.Options { + if item != nil { + Walk(v, item) + } + } + for _, item := range n.Tags { + walkTagAssignment(v, item) + } + for _, item := range n.UnsetTags { + if item != nil { + Walk(v, item) + } + } case *AlterFileFormatStmt: if n.Name != nil { Walk(v, n.Name) @@ -51,10 +108,38 @@ func walkChildren(v Visitor, node Node) { if n.NewName != nil { Walk(v, n.NewName) } + for _, item := range n.Options { + if item != nil { + Walk(v, item) + } + } case *AlterIntegrationStmt: if n.Name != nil { Walk(v, n.Name) } + for _, item := range n.Options { + if item != nil { + Walk(v, item) + } + } + for _, item := range n.Tags { + walkTagAssignment(v, item) + } + for _, item := range n.UnsetTags { + if item != nil { + Walk(v, item) + } + } + for _, item := range n.Triggers { + if item != nil { + Walk(v, item) + } + } + for _, item := range n.Accounts { + if item != nil { + Walk(v, item) + } + } case *AlterMaterializedViewStmt: if n.Name != nil { Walk(v, n.Name) @@ -67,10 +152,28 @@ func walkChildren(v Visitor, node Node) { if n.Name != nil { Walk(v, n.Name) } + for _, item := range n.Options { + if item != nil { + Walk(v, item) + } + } case *AlterPipeStmt: if n.Name != nil { Walk(v, n.Name) } + for _, item := range n.Options { + if item != nil { + Walk(v, item) + } + } + for _, item := range n.Tags { + walkTagAssignment(v, item) + } + for _, item := range n.UnsetTags { + if item != nil { + Walk(v, item) + } + } case *AlterPolicyStmt: if n.Name != nil { Walk(v, n.Name) @@ -79,6 +182,19 @@ func walkChildren(v Visitor, node Node) { Walk(v, n.NewName) } Walk(v, n.Body) + for _, item := range n.Options { + if item != nil { + Walk(v, item) + } + } + for _, item := range n.Tags { + walkTagAssignment(v, item) + } + for _, item := range n.UnsetTags { + if item != nil { + Walk(v, item) + } + } case *AlterReplicationGroupStmt: if n.Name != nil { Walk(v, n.Name) @@ -86,6 +202,24 @@ func walkChildren(v Visitor, node Node) { if n.NewName != nil { Walk(v, n.NewName) } + for _, item := range n.Options { + if item != nil { + Walk(v, item) + } + } + for _, item := range n.Tags { + walkTagAssignment(v, item) + } + for _, item := range n.UnsetTags { + if item != nil { + Walk(v, item) + } + } + for _, item := range n.Names { + if item != nil { + Walk(v, item) + } + } if n.MoveTo != nil { Walk(v, n.MoveTo) } @@ -96,13 +230,31 @@ func walkChildren(v Visitor, node Node) { if n.NewName != nil { Walk(v, n.NewName) } + for _, item := range n.Tags { + walkTagAssignment(v, item) + } + for _, item := range n.UnsetTags { + if item != nil { + Walk(v, item) + } + } case *AlterRoutineStmt: if n.Name != nil { Walk(v, n.Name) } + for _, item := range n.ArgTypes { + if item != nil { + Walk(v, item) + } + } if n.NewName != nil { Walk(v, n.NewName) } + for _, item := range n.Options { + if item != nil { + Walk(v, item) + } + } case *AlterSchemaStmt: if n.Name != nil { Walk(v, n.Name) @@ -110,6 +262,14 @@ func walkChildren(v Visitor, node Node) { if n.NewName != nil { Walk(v, n.NewName) } + for _, item := range n.Tags { + walkTagAssignment(v, item) + } + for _, item := range n.UnsetTags { + if item != nil { + Walk(v, item) + } + } case *AlterSemanticViewStmt: if n.Name != nil { Walk(v, n.Name) @@ -117,6 +277,19 @@ func walkChildren(v Visitor, node Node) { if n.NewName != nil { Walk(v, n.NewName) } + for _, item := range n.Options { + if item != nil { + Walk(v, item) + } + } + for _, item := range n.Tags { + walkTagAssignment(v, item) + } + for _, item := range n.UnsetTags { + if item != nil { + Walk(v, item) + } + } case *AlterSequenceStmt: if n.Name != nil { Walk(v, n.Name) @@ -124,10 +297,34 @@ func walkChildren(v Visitor, node Node) { if n.NewName != nil { Walk(v, n.NewName) } + case *AlterSessionStmt: + for _, item := range n.Options { + if item != nil { + Walk(v, item) + } + } case *AlterShareStmt: if n.Name != nil { Walk(v, n.Name) } + for _, item := range n.Accounts { + if item != nil { + Walk(v, item) + } + } + for _, item := range n.Options { + if item != nil { + Walk(v, item) + } + } + for _, item := range n.Tags { + walkTagAssignment(v, item) + } + for _, item := range n.UnsetTags { + if item != nil { + Walk(v, item) + } + } case *AlterStageStmt: if n.Name != nil { Walk(v, n.Name) @@ -135,14 +332,43 @@ func walkChildren(v Visitor, node Node) { if n.NewName != nil { Walk(v, n.NewName) } + for _, item := range n.Options { + if item != nil { + Walk(v, item) + } + } + for _, item := range n.Tags { + walkTagAssignment(v, item) + } + for _, item := range n.UnsetTags { + if item != nil { + Walk(v, item) + } + } case *AlterStreamStmt: if n.Name != nil { Walk(v, n.Name) } + for _, item := range n.Options { + if item != nil { + Walk(v, item) + } + } + for _, item := range n.Tags { + walkTagAssignment(v, item) + } + for _, item := range n.UnsetTags { + if item != nil { + Walk(v, item) + } + } case *AlterTableStmt: if n.Name != nil { Walk(v, n.Name) } + for _, item := range n.Actions { + walkAlterTableAction(v, item) + } case *AlterTagStmt: if n.Name != nil { Walk(v, n.Name) @@ -150,10 +376,38 @@ func walkChildren(v Visitor, node Node) { if n.NewName != nil { Walk(v, n.NewName) } + for _, item := range n.Options { + if item != nil { + Walk(v, item) + } + } + for _, item := range n.MaskingPolicies { + if item != nil { + Walk(v, item) + } + } case *AlterTaskStmt: if n.Name != nil { Walk(v, n.Name) } + for _, item := range n.After { + if item != nil { + Walk(v, item) + } + } + for _, item := range n.Options { + if item != nil { + Walk(v, item) + } + } + for _, item := range n.Tags { + walkTagAssignment(v, item) + } + for _, item := range n.UnsetTags { + if item != nil { + Walk(v, item) + } + } Walk(v, n.When) Walk(v, n.Body) case *AlterUserStmt: @@ -163,6 +417,19 @@ func walkChildren(v Visitor, node Node) { if n.NewName != nil { Walk(v, n.NewName) } + for _, item := range n.Options { + if item != nil { + Walk(v, item) + } + } + for _, item := range n.Tags { + walkTagAssignment(v, item) + } + for _, item := range n.UnsetTags { + if item != nil { + Walk(v, item) + } + } case *AlterViewStmt: if n.Name != nil { Walk(v, n.Name) @@ -170,6 +437,14 @@ func walkChildren(v Visitor, node Node) { if n.NewName != nil { Walk(v, n.NewName) } + for _, item := range n.Tags { + walkTagAssignment(v, item) + } + for _, item := range n.UnsetTags { + if item != nil { + Walk(v, item) + } + } if n.PolicyName != nil { Walk(v, n.PolicyName) } @@ -183,6 +458,24 @@ func walkChildren(v Visitor, node Node) { if n.NewName != nil { Walk(v, n.NewName) } + for _, item := range n.Options { + if item != nil { + Walk(v, item) + } + } + for _, item := range n.Tags { + walkTagAssignment(v, item) + } + for _, item := range n.UnsetTags { + if item != nil { + Walk(v, item) + } + } + for _, item := range n.Tables { + if item != nil { + Walk(v, item) + } + } case *ArrayLiteralExpr: walkNodes(v, n.Elements) case *BetweenExpr: @@ -201,6 +494,9 @@ func walkChildren(v Visitor, node Node) { walkNodes(v, n.Args) case *CaseExpr: Walk(v, n.Operand) + for _, item := range n.Whens { + walkWhenClause(v, item) + } Walk(v, n.Else) case *CastExpr: Walk(v, n.Expr) @@ -224,6 +520,10 @@ func walkChildren(v Visitor, node Node) { if n.MaskingPolicy != nil { Walk(v, n.MaskingPolicy) } + walkInlineConstraint(v, n.InlineConstraint) + for _, item := range n.Tags { + walkTagAssignment(v, item) + } Walk(v, n.VirtualExpr) case *CommentStmt: if n.Name != nil { @@ -232,6 +532,11 @@ func walkChildren(v Visitor, node Node) { if n.Column != nil { Walk(v, n.Column) } + for _, item := range n.Signature { + if item != nil { + Walk(v, item) + } + } case *ConnectionReplica: if n.Source != nil { Walk(v, n.Source) @@ -245,6 +550,11 @@ func walkChildren(v Visitor, node Node) { } Walk(v, n.FromQuery) Walk(v, n.Partition) + for _, item := range n.Options { + if item != nil { + Walk(v, item) + } + } case *CopyIntoTableStmt: if n.Target != nil { Walk(v, n.Target) @@ -255,24 +565,56 @@ func walkChildren(v Visitor, node Node) { if n.Transform != nil { Walk(v, n.Transform) } + for _, item := range n.Options { + if item != nil { + Walk(v, item) + } + } case *CopyOption: if n.Lit != nil { Walk(v, n.Lit) } + for _, item := range n.Group { + if item != nil { + Walk(v, item) + } + } + for _, item := range n.List { + if item != nil { + Walk(v, item) + } + } case *CreateAccountStmt: if n.Name != nil { Walk(v, n.Name) } + for _, item := range n.Options { + if item != nil { + Walk(v, item) + } + } case *CreateAlertStmt: if n.Name != nil { Walk(v, n.Name) } + for _, item := range n.Tags { + walkTagAssignment(v, item) + } + for _, item := range n.Options { + if item != nil { + Walk(v, item) + } + } Walk(v, n.Condition) Walk(v, n.Action) case *CreateDatabaseStmt: if n.Name != nil { Walk(v, n.Name) } + walkCloneSource(v, n.Clone) + for _, item := range n.Tags { + walkTagAssignment(v, item) + } case *CreateDatasetStmt: if n.Name != nil { Walk(v, n.Name) @@ -281,32 +623,82 @@ func walkChildren(v Visitor, node Node) { if n.Name != nil { Walk(v, n.Name) } + for _, item := range n.Columns { + if item != nil { + Walk(v, item) + } + } + for _, item := range n.Options { + if item != nil { + Walk(v, item) + } + } walkNodes(v, n.ClusterBy) Walk(v, n.ImmutableWhere) Walk(v, n.AsQuery) Walk(v, n.RefreshUsing) + walkCloneSource(v, n.Clone) case *CreateEventTableStmt: if n.Name != nil { Walk(v, n.Name) } walkNodes(v, n.ClusterBy) + for _, item := range n.Options { + if item != nil { + Walk(v, item) + } + } + for _, item := range n.Tags { + walkTagAssignment(v, item) + } case *CreateExternalTableStmt: if n.Name != nil { Walk(v, n.Name) } + for _, item := range n.Columns { + if item != nil { + Walk(v, item) + } + } Walk(v, n.UsingTemplate) walkNodes(v, n.PartitionBy) if n.Location != nil { Walk(v, n.Location) } + for _, item := range n.Options { + if item != nil { + Walk(v, item) + } + } + for _, item := range n.Tags { + walkTagAssignment(v, item) + } case *CreateFileFormatStmt: if n.Name != nil { Walk(v, n.Name) } + for _, item := range n.Options { + if item != nil { + Walk(v, item) + } + } case *CreateIntegrationStmt: if n.Name != nil { Walk(v, n.Name) } + for _, item := range n.Options { + if item != nil { + Walk(v, item) + } + } + for _, item := range n.Tags { + walkTagAssignment(v, item) + } + for _, item := range n.Triggers { + if item != nil { + Walk(v, item) + } + } if n.Replica != nil { Walk(v, n.Replica) } @@ -314,29 +706,67 @@ func walkChildren(v Visitor, node Node) { if n.Name != nil { Walk(v, n.Name) } + for _, item := range n.Columns { + walkViewColumn(v, item) + } + for _, item := range n.ViewCols { + walkViewColumn(v, item) + } walkNodes(v, n.ClusterBy) + for _, item := range n.Tags { + walkTagAssignment(v, item) + } + walkRowAccessPolicy(v, n.RowPolicy) Walk(v, n.Query) case *CreateNetworkRuleStmt: if n.Name != nil { Walk(v, n.Name) } + for _, item := range n.Options { + if item != nil { + Walk(v, item) + } + } case *CreatePipeStmt: if n.Name != nil { Walk(v, n.Name) } + for _, item := range n.Options { + if item != nil { + Walk(v, item) + } + } Walk(v, n.Copy) case *CreatePolicyStmt: if n.Name != nil { Walk(v, n.Name) } + for _, item := range n.Args { + if item != nil { + Walk(v, item) + } + } if n.Returns != nil { Walk(v, n.Returns) } Walk(v, n.Body) + for _, item := range n.Options { + if item != nil { + Walk(v, item) + } + } case *CreateReplicationGroupStmt: if n.Name != nil { Walk(v, n.Name) } + for _, item := range n.Options { + if item != nil { + Walk(v, item) + } + } + for _, item := range n.Tags { + walkTagAssignment(v, item) + } if n.Replica != nil { Walk(v, n.Replica) } @@ -344,21 +774,52 @@ func walkChildren(v Visitor, node Node) { if n.Name != nil { Walk(v, n.Name) } + for _, item := range n.Tags { + walkTagAssignment(v, item) + } case *CreateRoutineStmt: if n.Name != nil { Walk(v, n.Name) } + for _, item := range n.Args { + walkRoutineArg(v, item) + } if n.ReturnType != nil { Walk(v, n.ReturnType) } + for _, item := range n.ReturnTable { + walkRoutineTableColumn(v, item) + } + for _, item := range n.Options { + if item != nil { + Walk(v, item) + } + } case *CreateSchemaStmt: if n.Name != nil { Walk(v, n.Name) } + walkCloneSource(v, n.Clone) + for _, item := range n.Tags { + walkTagAssignment(v, item) + } case *CreateSemanticViewStmt: if n.Name != nil { Walk(v, n.Name) } + for _, item := range n.Sections { + if item != nil { + Walk(v, item) + } + } + for _, item := range n.Options { + if item != nil { + Walk(v, item) + } + } + for _, item := range n.Tags { + walkTagAssignment(v, item) + } case *CreateSequenceStmt: if n.Name != nil { Walk(v, n.Name) @@ -367,17 +828,39 @@ func walkChildren(v Visitor, node Node) { if n.Name != nil { Walk(v, n.Name) } + for _, item := range n.Options { + if item != nil { + Walk(v, item) + } + } case *CreateStageStmt: if n.Name != nil { Walk(v, n.Name) } + for _, item := range n.Options { + if item != nil { + Walk(v, item) + } + } + for _, item := range n.Tags { + walkTagAssignment(v, item) + } case *CreateStreamStmt: if n.Name != nil { Walk(v, n.Name) } + for _, item := range n.Tags { + walkTagAssignment(v, item) + } if n.Source != nil { Walk(v, n.Source) } + walkStreamTimeTravel(v, n.TimeTravel) + for _, item := range n.Options { + if item != nil { + Walk(v, item) + } + } if n.Clone != nil { Walk(v, n.Clone) } @@ -385,19 +868,56 @@ func walkChildren(v Visitor, node Node) { if n.Name != nil { Walk(v, n.Name) } + for _, item := range n.Columns { + if item != nil { + Walk(v, item) + } + } + for _, item := range n.Constraints { + if item != nil { + Walk(v, item) + } + } + for _, item := range n.Indexes { + if item != nil { + Walk(v, item) + } + } walkNodes(v, n.ClusterBy) + for _, item := range n.Tags { + walkTagAssignment(v, item) + } Walk(v, n.AsSelect) if n.Like != nil { Walk(v, n.Like) } + walkCloneSource(v, n.Clone) case *CreateTagStmt: if n.Name != nil { Walk(v, n.Name) } + for _, item := range n.Options { + if item != nil { + Walk(v, item) + } + } case *CreateTaskStmt: if n.Name != nil { Walk(v, n.Name) } + for _, item := range n.Tags { + walkTagAssignment(v, item) + } + for _, item := range n.Options { + if item != nil { + Walk(v, item) + } + } + for _, item := range n.After { + if item != nil { + Walk(v, item) + } + } Walk(v, n.When) Walk(v, n.Body) if n.Clone != nil { @@ -407,15 +927,41 @@ func walkChildren(v Visitor, node Node) { if n.Name != nil { Walk(v, n.Name) } + for _, item := range n.Options { + if item != nil { + Walk(v, item) + } + } + for _, item := range n.Tags { + walkTagAssignment(v, item) + } case *CreateViewStmt: if n.Name != nil { Walk(v, n.Name) } + for _, item := range n.Columns { + walkViewColumn(v, item) + } + for _, item := range n.ViewCols { + walkViewColumn(v, item) + } + for _, item := range n.Tags { + walkTagAssignment(v, item) + } + walkRowAccessPolicy(v, n.RowPolicy) Walk(v, n.Query) case *CreateWarehouseStmt: if n.Name != nil { Walk(v, n.Name) } + for _, item := range n.Options { + if item != nil { + Walk(v, item) + } + } + for _, item := range n.Tags { + walkTagAssignment(v, item) + } case *DeleteStmt: if n.Target != nil { Walk(v, n.Target) @@ -429,6 +975,11 @@ func walkChildren(v Visitor, node Node) { if n.NameLiteral != nil { Walk(v, n.NameLiteral) } + for _, item := range n.Signature { + if item != nil { + Walk(v, item) + } + } case *DollarRef: if n.Qualifier != nil { Walk(v, n.Qualifier) @@ -458,10 +1009,16 @@ func walkChildren(v Visitor, node Node) { Walk(v, n.DataType) } Walk(v, n.Expr) + walkInlineConstraint(v, n.Constraint) case *File: walkNodes(v, n.Stmts) case *FuncCallExpr: + Walk(v, &n.Name) walkNodes(v, n.Args) + for _, item := range n.OrderBy { + walkOrderItem(v, item) + } + walkWindowSpec(v, n.Over) case *GetStmt: if n.Stage != nil { Walk(v, n.Stage) @@ -469,6 +1026,11 @@ func walkChildren(v Visitor, node Node) { if n.Target != nil { Walk(v, n.Target) } + for _, item := range n.Options { + if item != nil { + Walk(v, item) + } + } case *GrantStmt: if n.Role != nil { Walk(v, n.Role) @@ -483,6 +1045,11 @@ func walkChildren(v Visitor, node Node) { if n.Name != nil { Walk(v, n.Name) } + for _, item := range n.Signature { + if item != nil { + Walk(v, item) + } + } if n.ContainerName != nil { Walk(v, n.ContainerName) } @@ -502,6 +1069,9 @@ func walkChildren(v Visitor, node Node) { Walk(v, n.Expr) walkNodes(v, n.Values) case *InsertMultiStmt: + for _, item := range n.Branches { + walkInsertMultiBranch(v, item) + } Walk(v, n.Select) case *InsertStmt: if n.Target != nil { @@ -519,6 +1089,10 @@ func walkChildren(v Visitor, node Node) { Walk(v, n.Right) Walk(v, n.On) Walk(v, n.MatchCondition) + case *JsonLiteralExpr: + for i := range n.Pairs { + walkKeyValuePair(v, &n.Pairs[i]) + } case *LambdaExpr: Walk(v, n.Body) case *LikeExpr: @@ -539,6 +1113,14 @@ func walkChildren(v Visitor, node Node) { Walk(v, n.Expr) case *MatchRecognizeClause: walkNodes(v, n.PartitionBy) + for _, item := range n.OrderBy { + walkOrderItem(v, item) + } + for _, item := range n.Measures { + if item != nil { + Walk(v, item) + } + } if n.RowsPerMatch != nil { Walk(v, n.RowsPerMatch) } @@ -548,12 +1130,20 @@ func walkChildren(v Visitor, node Node) { if n.Pattern != nil { Walk(v, n.Pattern) } + for _, item := range n.Define { + if item != nil { + Walk(v, item) + } + } case *MergeStmt: if n.Target != nil { Walk(v, n.Target) } Walk(v, n.Source) Walk(v, n.On) + for _, item := range n.Whens { + walkMergeWhen(v, item) + } case *OuterJoinExpr: Walk(v, n.Operand) case *ParenExpr: @@ -570,6 +1160,14 @@ func walkChildren(v Visitor, node Node) { } Walk(v, n.DefaultVal) case *PivotInClause: + for _, item := range n.Values { + if item != nil { + Walk(v, item) + } + } + for _, item := range n.OrderBy { + walkOrderItem(v, item) + } Walk(v, n.Subquery) case *PivotValue: Walk(v, n.Value) @@ -584,6 +1182,11 @@ func walkChildren(v Visitor, node Node) { if n.Stage != nil { Walk(v, n.Stage) } + for _, item := range n.Options { + if item != nil { + Walk(v, item) + } + } case *RemoveStmt: if n.Stage != nil { Walk(v, n.Stage) @@ -658,18 +1261,33 @@ func walkChildren(v Visitor, node Node) { Walk(v, n.Cond) walkNodes(v, n.Body) case *SelectStmt: + for _, item := range n.With { + walkCTE(v, item) + } Walk(v, n.Top) + for _, item := range n.Targets { + walkSelectTarget(v, item) + } walkNodes(v, n.From) Walk(v, n.Where) + walkGroupByClause(v, n.GroupBy) Walk(v, n.Having) Walk(v, n.Qualify) Walk(v, n.StartWith) walkNodes(v, n.ConnectBy) + for _, item := range n.OrderBy { + walkOrderItem(v, item) + } Walk(v, n.Limit) Walk(v, n.Offset) + walkFetchClause(v, n.Fetch) case *SetOperationStmt: Walk(v, n.Left) Walk(v, n.Right) + case *SetStmt: + for _, item := range n.Vars { + walkSetVar(v, item) + } case *ShowStmt: if n.ScopeName != nil { Walk(v, n.ScopeName) @@ -687,6 +1305,8 @@ func walkChildren(v Visitor, node Node) { } case *SubqueryExpr: Walk(v, n.Query) + case *TableConstraint: + walkForeignKeyRef(v, n.References) case *TableRef: if n.Name != nil { Walk(v, n.Name) @@ -740,10 +1360,19 @@ func walkChildren(v Visitor, node Node) { if n.Name != nil { Walk(v, n.Name) } + case *UnpivotClause: + for _, item := range n.Columns { + if item != nil { + Walk(v, item) + } + } case *UpdateStmt: if n.Target != nil { Walk(v, n.Target) } + for _, item := range n.Sets { + walkUpdateSet(v, item) + } walkNodes(v, n.From) Walk(v, n.Where) case *UseStmt: @@ -754,3 +1383,321 @@ func walkChildren(v Visitor, node Node) { walkNodeRows(v, n.Rows) } } + +// walkAlterTableAction walks the node-bearing fields of the non-Node helper struct +// AlterTableAction. The helper itself is not visited (it is not a Node); its child +// nodes are walked in field order. +func walkAlterTableAction(v Visitor, n *AlterTableAction) { + if n == nil { + return + } + if n.NewName != nil { + Walk(v, n.NewName) + } + for _, item := range n.Columns { + if item != nil { + Walk(v, item) + } + } + for _, item := range n.ColumnAlters { + walkColumnAlter(v, item) + } + if n.Constraint != nil { + Walk(v, n.Constraint) + } + walkNodes(v, n.ClusterBy) + Walk(v, n.ReclusterWhere) + for _, item := range n.Tags { + walkTagAssignment(v, item) + } + for _, item := range n.UnsetTags { + if item != nil { + Walk(v, item) + } + } + if n.PolicyName != nil { + Walk(v, n.PolicyName) + } + if n.MaskingPolicy != nil { + Walk(v, n.MaskingPolicy) + } +} + +// walkCTE walks the node-bearing fields of the non-Node helper struct +// CTE. The helper itself is not visited (it is not a Node); its child +// nodes are walked in field order. +func walkCTE(v Visitor, n *CTE) { + if n == nil { + return + } + Walk(v, n.Query) +} + +// walkCloneSource walks the node-bearing fields of the non-Node helper struct +// CloneSource. The helper itself is not visited (it is not a Node); its child +// nodes are walked in field order. +func walkCloneSource(v Visitor, n *CloneSource) { + if n == nil { + return + } + if n.Source != nil { + Walk(v, n.Source) + } + Walk(v, n.Value) +} + +// walkColumnAlter walks the node-bearing fields of the non-Node helper struct +// ColumnAlter. The helper itself is not visited (it is not a Node); its child +// nodes are walked in field order. +func walkColumnAlter(v Visitor, n *ColumnAlter) { + if n == nil { + return + } + if n.DataType != nil { + Walk(v, n.DataType) + } + Walk(v, n.DefaultExpr) +} + +// walkFetchClause walks the node-bearing fields of the non-Node helper struct +// FetchClause. The helper itself is not visited (it is not a Node); its child +// nodes are walked in field order. +func walkFetchClause(v Visitor, n *FetchClause) { + if n == nil { + return + } + Walk(v, n.Count) +} + +// walkForeignKeyRef walks the node-bearing fields of the non-Node helper struct +// ForeignKeyRef. The helper itself is not visited (it is not a Node); its child +// nodes are walked in field order. +func walkForeignKeyRef(v Visitor, n *ForeignKeyRef) { + if n == nil { + return + } + if n.Table != nil { + Walk(v, n.Table) + } +} + +// walkGroupByClause walks the node-bearing fields of the non-Node helper struct +// GroupByClause. The helper itself is not visited (it is not a Node); its child +// nodes are walked in field order. +func walkGroupByClause(v Visitor, n *GroupByClause) { + if n == nil { + return + } + walkNodes(v, n.Items) +} + +// walkInlineConstraint walks the node-bearing fields of the non-Node helper struct +// InlineConstraint. The helper itself is not visited (it is not a Node); its child +// nodes are walked in field order. +func walkInlineConstraint(v Visitor, n *InlineConstraint) { + if n == nil { + return + } + walkForeignKeyRef(v, n.References) +} + +// walkInsertMultiBranch walks the node-bearing fields of the non-Node helper struct +// InsertMultiBranch. The helper itself is not visited (it is not a Node); its child +// nodes are walked in field order. +func walkInsertMultiBranch(v Visitor, n *InsertMultiBranch) { + if n == nil { + return + } + Walk(v, n.When) + if n.Target != nil { + Walk(v, n.Target) + } + walkNodes(v, n.Values) +} + +// walkKeyValuePair walks the node-bearing fields of the non-Node helper struct +// KeyValuePair. The helper itself is not visited (it is not a Node); its child +// nodes are walked in field order. +func walkKeyValuePair(v Visitor, n *KeyValuePair) { + if n == nil { + return + } + Walk(v, n.Value) +} + +// walkMergeWhen walks the node-bearing fields of the non-Node helper struct +// MergeWhen. The helper itself is not visited (it is not a Node); its child +// nodes are walked in field order. +func walkMergeWhen(v Visitor, n *MergeWhen) { + if n == nil { + return + } + Walk(v, n.AndCond) + for _, item := range n.Sets { + walkUpdateSet(v, item) + } + walkNodes(v, n.InsertVals) +} + +// walkOrderItem walks the node-bearing fields of the non-Node helper struct +// OrderItem. The helper itself is not visited (it is not a Node); its child +// nodes are walked in field order. +func walkOrderItem(v Visitor, n *OrderItem) { + if n == nil { + return + } + Walk(v, n.Expr) +} + +// walkRoutineArg walks the node-bearing fields of the non-Node helper struct +// RoutineArg. The helper itself is not visited (it is not a Node); its child +// nodes are walked in field order. +func walkRoutineArg(v Visitor, n *RoutineArg) { + if n == nil { + return + } + if n.Type != nil { + Walk(v, n.Type) + } + Walk(v, n.Default) +} + +// walkRoutineTableColumn walks the node-bearing fields of the non-Node helper struct +// RoutineTableColumn. The helper itself is not visited (it is not a Node); its child +// nodes are walked in field order. +func walkRoutineTableColumn(v Visitor, n *RoutineTableColumn) { + if n == nil { + return + } + if n.Type != nil { + Walk(v, n.Type) + } +} + +// walkRowAccessPolicy walks the node-bearing fields of the non-Node helper struct +// RowAccessPolicy. The helper itself is not visited (it is not a Node); its child +// nodes are walked in field order. +func walkRowAccessPolicy(v Visitor, n *RowAccessPolicy) { + if n == nil { + return + } + if n.PolicyName != nil { + Walk(v, n.PolicyName) + } +} + +// walkSelectTarget walks the node-bearing fields of the non-Node helper struct +// SelectTarget. The helper itself is not visited (it is not a Node); its child +// nodes are walked in field order. +func walkSelectTarget(v Visitor, n *SelectTarget) { + if n == nil { + return + } + Walk(v, n.Expr) +} + +// walkSetVar walks the node-bearing fields of the non-Node helper struct +// SetVar. The helper itself is not visited (it is not a Node); its child +// nodes are walked in field order. +func walkSetVar(v Visitor, n *SetVar) { + if n == nil { + return + } + Walk(v, n.Value) +} + +// walkStreamTimeTravel walks the node-bearing fields of the non-Node helper struct +// StreamTimeTravel. The helper itself is not visited (it is not a Node); its child +// nodes are walked in field order. +func walkStreamTimeTravel(v Visitor, n *StreamTimeTravel) { + if n == nil { + return + } + Walk(v, n.Value) +} + +// walkTagAssignment walks the node-bearing fields of the non-Node helper struct +// TagAssignment. The helper itself is not visited (it is not a Node); its child +// nodes are walked in field order. +func walkTagAssignment(v Visitor, n *TagAssignment) { + if n == nil { + return + } + if n.Name != nil { + Walk(v, n.Name) + } +} + +// walkUpdateSet walks the node-bearing fields of the non-Node helper struct +// UpdateSet. The helper itself is not visited (it is not a Node); its child +// nodes are walked in field order. +func walkUpdateSet(v Visitor, n *UpdateSet) { + if n == nil { + return + } + if n.Column != nil { + Walk(v, n.Column) + } + Walk(v, n.Value) +} + +// walkViewColumn walks the node-bearing fields of the non-Node helper struct +// ViewColumn. The helper itself is not visited (it is not a Node); its child +// nodes are walked in field order. +func walkViewColumn(v Visitor, n *ViewColumn) { + if n == nil { + return + } + if n.MaskingPolicy != nil { + Walk(v, n.MaskingPolicy) + } + for _, item := range n.Tags { + walkTagAssignment(v, item) + } +} + +// walkWhenClause walks the node-bearing fields of the non-Node helper struct +// WhenClause. The helper itself is not visited (it is not a Node); its child +// nodes are walked in field order. +func walkWhenClause(v Visitor, n *WhenClause) { + if n == nil { + return + } + Walk(v, n.Cond) + Walk(v, n.Result) +} + +// walkWindowBound walks the node-bearing fields of the non-Node helper struct +// WindowBound. The helper itself is not visited (it is not a Node); its child +// nodes are walked in field order. +func walkWindowBound(v Visitor, n *WindowBound) { + if n == nil { + return + } + Walk(v, n.Offset) +} + +// walkWindowFrame walks the node-bearing fields of the non-Node helper struct +// WindowFrame. The helper itself is not visited (it is not a Node); its child +// nodes are walked in field order. +func walkWindowFrame(v Visitor, n *WindowFrame) { + if n == nil { + return + } + walkWindowBound(v, &n.Start) + walkWindowBound(v, &n.End) +} + +// walkWindowSpec walks the node-bearing fields of the non-Node helper struct +// WindowSpec. The helper itself is not visited (it is not a Node); its child +// nodes are walked in field order. +func walkWindowSpec(v Visitor, n *WindowSpec) { + if n == nil { + return + } + walkNodes(v, n.PartitionBy) + for _, item := range n.OrderBy { + walkOrderItem(v, item) + } + walkWindowFrame(v, n.Frame) +}