Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 149 additions & 14 deletions snowflake/analysis/query_span.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,17 @@ type tableEntry struct {
// For a base table with unknown schema, we synthesize a * source column
// rather than enumerating specific columns.
isUnknown bool // true when we have no schema information
// opaque marks a PIVOT/UNPIVOT-derived relation. The output columns of a
// pivoted relation are computed from its source relation but carry names
// that do not exist on it (pivot value columns / the UNPIVOT value+name
// columns), and without catalog information the projection cannot be
// enumerated. To stay masking-sound, every column resolved through such
// an entry — star or named, qualified or bare — attributes to these
// whole-relation ("*") sources instead of fabricating per-column
// attributions. Over-attribution is deliberate: claiming base.* for a
// pivot output is conservative, while claiming base.Q1 for a column that
// actually carries agg(base.amount) would let a masking consumer leak.
opaque []*SourceColumn
}

// ---------------------------------------------------------------------------
Expand Down Expand Up @@ -176,6 +187,17 @@ func extractSelectStmt(s *ast.SelectStmt, scope *queryScope) *QuerySpan {
// Also include "table-level" sources for every FROM entry even if no
// columns are explicitly referenced (e.g. SELECT 1 FROM t still "touches" t).
for _, entry := range fromEntries {
if len(entry.opaque) > 0 {
// Pivoted relation: the alias names a derived relation, not a
// real table — record the underlying whole-relation sources.
for _, sc := range entry.opaque {
k := toSourceKey(sc)
if _, exists := sourcesMap[k]; !exists {
sourcesMap[k] = sc
}
}
continue
}
sc := &SourceColumn{Table: entry.alias, Column: "*"}
k := toSourceKey(sc)
if _, exists := sourcesMap[k]; !exists {
Expand Down Expand Up @@ -281,6 +303,13 @@ func resolveFromItem(item ast.Node, scope *queryScope) []tableEntry {

// resolveTableRef resolves a single TableRef to a tableEntry.
func resolveTableRef(ref *ast.TableRef, scope *queryScope) tableEntry {
// PIVOT/UNPIVOT (possibly chained via Nested): the relation's output
// columns are derived from the source relation under names the source
// does not have — resolve as an opaque-derived entry.
if ref.Pivot != nil || ref.Unpivot != nil || ref.Nested != nil {
return resolvePivotedRef(ref, scope)
}

// Subquery: (SELECT ...) AS alias
if ref.Subquery != nil {
subSpan := extractSpan(ref.Subquery, scope)
Expand Down Expand Up @@ -351,6 +380,70 @@ func resolveJoin(join *ast.JoinExpr, scope *queryScope) []tableEntry {
return append(left, right...)
}

// resolvePivotedRef resolves a TableRef that carries PIVOT/UNPIVOT clauses
// (or a Nested chain of them) to an opaque-derived tableEntry: the entry is
// addressed by the clause's trailing alias (PIVOT(...) AS p) and every column
// resolved through it attributes to the underlying relation's whole-relation
// sources. See tableEntry.opaque for the soundness rationale.
func resolvePivotedRef(ref *ast.TableRef, scope *queryScope) tableEntry {
// Resolve the underlying source relation.
var base tableEntry
if ref.Nested != nil {
base = resolveTableRef(ref.Nested, scope)
} else {
stripped := *ref
stripped.Pivot = nil
stripped.Unpivot = nil
base = resolveTableRef(&stripped, scope)
}

// The pivoted relation is addressed by the clause's trailing alias when
// present (the parser stores it on the clause, not on the TableRef).
alias := ""
if ref.Pivot != nil {
alias = ref.Pivot.Alias.Normalize()
}
if alias == "" && ref.Unpivot != nil {
alias = ref.Unpivot.Alias.Normalize()
}
if alias == "" {
alias = ref.Alias.Normalize()
}
if alias == "" {
alias = base.alias
}

return tableEntry{
alias: alias,
isUnknown: true,
opaque: collapseEntryToOpaque(base),
}
}

// collapseEntryToOpaque reduces a resolved entry to its whole-relation ("*")
// source set, used to attribute every column of a pivoted relation.
func collapseEntryToOpaque(e tableEntry) []*SourceColumn {
if len(e.opaque) > 0 {
// Already opaque (chained PIVOT): keep the original base sources.
return e.opaque
}
var out []*SourceColumn
named := false
for _, c := range e.columns {
if c.Column == "*" {
out = append(out, c)
} else {
named = true
}
}
if named || len(out) == 0 {
// Subquery/CTE entries expose named columns; collapse them to a
// single whole-relation source under the entry's alias.
out = append(out, &SourceColumn{Table: e.alias, Column: "*"})
}
return out
}

// ---------------------------------------------------------------------------
// SelectTarget resolution
// ---------------------------------------------------------------------------
Expand Down Expand Up @@ -394,6 +487,12 @@ func resolveStarTarget(target *ast.SelectTarget, fromEntries []tableEntry) *Resu

// expandEntryStar expands a tableEntry for * selection, filtering excluded columns.
func expandEntryStar(entry tableEntry, excluded map[string]bool) []*SourceColumn {
if len(entry.opaque) > 0 {
// Pivoted relation: output column names are unrelated to the base
// columns, so EXCLUDE filtering cannot apply — return the
// whole-relation sources.
return entry.opaque
}
if entry.isUnknown || len(entry.columns) == 0 {
// Unknown table: emit a * pseudo-source.
return []*SourceColumn{{Table: entry.alias, Column: "*"}}
Expand Down Expand Up @@ -475,11 +574,12 @@ func collectColumnRefs(expr ast.Node, fromEntries []tableEntry) []*SourceColumn
}
switch n := node.(type) {
case *ast.ColumnRef:
sc := resolveColumnRef(n, fromEntries)
k := toSourceKey(sc)
if !seen[k] {
seen[k] = true
refs = append(refs, sc)
for _, sc := range resolveColumnRef(n, fromEntries) {
k := toSourceKey(sc)
if !seen[k] {
seen[k] = true
refs = append(refs, sc)
}
}
return false // don't recurse into ColumnRef parts
case *ast.StarExpr:
Expand All @@ -503,24 +603,59 @@ func collectColumnRefs(expr ast.Node, fromEntries []tableEntry) []*SourceColumn
return refs
}

// resolveColumnRef resolves a ColumnRef to a SourceColumn, looking up the
// resolveColumnRef resolves a ColumnRef to its SourceColumns, looking up the
// table qualifier from the FROM entries when only a column name is provided.
func resolveColumnRef(ref *ast.ColumnRef, fromEntries []tableEntry) *SourceColumn {
// A reference into a pivoted (opaque) relation resolves to that relation's
// whole-relation sources rather than a fabricated per-column attribution.
func resolveColumnRef(ref *ast.ColumnRef, fromEntries []tableEntry) []*SourceColumn {
parts := ref.Parts

// Qualified ref whose qualifier names a pivoted relation: attribute to
// the relation's whole-relation source set (e.g. p.Q1 with
// `t PIVOT(...) AS p` reads derived data, not a base column "Q1").
if len(parts) >= 2 {
qual := parts[0].Normalize()
for _, entry := range fromEntries {
if len(entry.opaque) > 0 && strings.EqualFold(entry.alias, qual) {
return entry.opaque
}
}
}

switch len(parts) {
case 1:
col := parts[0].Normalize()
// Try to find which table this column comes from.
table := resolveColumnToTable(col, fromEntries)
return &SourceColumn{Table: table, Column: col}
// A bare column in a scope containing pivoted relations may name one
// of their derived output columns: include those relations'
// whole-relation sources (conservative over-attribution).
var out []*SourceColumn
var plain []tableEntry
for _, entry := range fromEntries {
if len(entry.opaque) > 0 {
out = append(out, entry.opaque...)
} else {
plain = append(plain, entry)
}
}
if len(out) == 0 {
// No pivoted relations: existing resolution.
table := resolveColumnToTable(col, fromEntries)
return []*SourceColumn{{Table: table, Column: col}}
}
if len(plain) > 0 {
// The column may equally come from a non-pivoted relation.
table := resolveColumnToTable(col, plain)
out = append(out, &SourceColumn{Table: table, Column: col})
}
return out
case 2:
return &SourceColumn{Table: parts[0].Normalize(), Column: parts[1].Normalize()}
return []*SourceColumn{{Table: parts[0].Normalize(), Column: parts[1].Normalize()}}
case 3:
return &SourceColumn{Schema: parts[0].Normalize(), Table: parts[1].Normalize(), Column: parts[2].Normalize()}
return []*SourceColumn{{Schema: parts[0].Normalize(), Table: parts[1].Normalize(), Column: parts[2].Normalize()}}
case 4:
return &SourceColumn{Database: parts[0].Normalize(), Schema: parts[1].Normalize(), Table: parts[2].Normalize(), Column: parts[3].Normalize()}
return []*SourceColumn{{Database: parts[0].Normalize(), Schema: parts[1].Normalize(), Table: parts[2].Normalize(), Column: parts[3].Normalize()}}
}
return &SourceColumn{Column: strings.Join(identPartsToStrings(parts), ".")}
return []*SourceColumn{{Column: strings.Join(identPartsToStrings(parts), ".")}}
}

// resolveColumnToTable looks up which FROM entry contains the given column name.
Expand Down
121 changes: 121 additions & 0 deletions snowflake/analysis/query_span_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -646,3 +646,124 @@ func TestQuerySpan_WithinGroupColumns(t *testing.T) {
}
}
}

// ---------------------------------------------------------------------------
// PIVOT / UNPIVOT: opaque-derived relations (masking soundness)
// ---------------------------------------------------------------------------
//
// The output columns of a pivoted relation are computed from the source
// relation under names that do not exist on it. Without catalog information
// the projection cannot be enumerated, so every column resolved through a
// pivoted relation must attribute to the source's whole-relation ("*")
// sources. Fabricating per-column attributions (e.g. T."2023_Q1" for a pivot
// value column that actually carries SUM(T.AMOUNT)) would under-attribute and
// let a masking consumer leak.

func TestQuerySpan_PivotStar(t *testing.T) {
span := mustExtract(t, "SELECT * FROM t PIVOT(SUM(a) FOR m IN ('JAN', 'FEB')) AS p")

if len(span.Results) != 1 {
t.Fatalf("Results: got %d, want 1", len(span.Results))
}
keys := resultSourceKeys(span, 0)
want := []string{"..T.*"}
if len(keys) != 1 || keys[0] != want[0] {
t.Fatalf("result sources: got %v, want %v", keys, want)
}
// The pivot alias P must not surface as a source table.
for _, k := range sourceKeys(span) {
if strings.HasPrefix(k, "..P.") {
t.Errorf("pivot alias leaked into sources: %v", sourceKeys(span))
}
}
}

func TestQuerySpan_PivotValueColumnNotFabricated(t *testing.T) {
span := mustExtract(t, `SELECT "2023_Q1" FROM quarterly_sales PIVOT(SUM(amount) FOR quarter IN ('2023_Q1', '2023_Q2')) AS p`)

keys := resultSourceKeys(span, 0)
want := []string{"..QUARTERLY_SALES.*"}
if len(keys) != 1 || keys[0] != want[0] {
t.Fatalf("result sources: got %v, want %v", keys, want)
}
if hasSrcKey(span, "..QUARTERLY_SALES.2023_Q1") {
t.Error("fabricated base-column attribution for a pivot value column")
}
}

func TestQuerySpan_PivotQualifiedRef(t *testing.T) {
span := mustExtract(t, "SELECT p.q1 FROM db1.s1.t PIVOT(SUM(a) FOR m IN ('JAN')) AS p")

keys := resultSourceKeys(span, 0)
want := []string{"DB1.S1.T.*"}
if len(keys) != 1 || keys[0] != want[0] {
t.Fatalf("result sources: got %v, want %v", keys, want)
}
if hasSrcKey(span, "..P.Q1") {
t.Error("qualified ref through pivot alias must not fabricate table P")
}
}

func TestQuerySpan_PivotQualifiedStar(t *testing.T) {
span := mustExtract(t, "SELECT p.* FROM t PIVOT(SUM(a) FOR m IN ('JAN')) AS p")

keys := resultSourceKeys(span, 0)
want := []string{"..T.*"}
if len(keys) != 1 || keys[0] != want[0] {
t.Fatalf("result sources: got %v, want %v", keys, want)
}
}

func TestQuerySpan_UnpivotValueColumn(t *testing.T) {
span := mustExtract(t, "SELECT sales FROM monthly_sales UNPIVOT (sales FOR month IN (jan, feb)) unpvt")

keys := resultSourceKeys(span, 0)
want := []string{"..MONTHLY_SALES.*"}
if len(keys) != 1 || keys[0] != want[0] {
t.Fatalf("result sources: got %v, want %v", keys, want)
}
if hasSrcKey(span, "..MONTHLY_SALES.SALES") {
t.Error("UNPIVOT value column attributed to a nonexistent base column")
}
if hasSrcKey(span, "..UNPVT.SALES") {
t.Error("UNPIVOT alias fabricated as a source table")
}
}

func TestQuerySpan_PivotChainedOverSubquery(t *testing.T) {
span := mustExtract(t, `SELECT q1 FROM (SELECT a, b FROM real_table) PIVOT (SUM(a) FOR b IN ('x')) PIVOT (MAX(a) FOR b IN ('y')) AS pp`)

// The chain collapses to the subquery's whole-relation source (the
// package addresses unaliased subqueries as "_subquery").
keys := resultSourceKeys(span, 0)
want := []string{".._subquery.*"}
if len(keys) != 1 || keys[0] != want[0] {
t.Fatalf("result sources: got %v, want %v", keys, want)
}
}

func TestQuerySpan_PivotJoinMixedAttribution(t *testing.T) {
// A bare column in a scope with both a pivoted relation and a plain table
// may come from either: both must be attributed (conservative).
span := mustExtract(t, "SELECT x FROM t PIVOT(SUM(a) FOR m IN ('J')) AS p, u")

keys := resultSourceKeys(span, 0)
if len(keys) != 2 {
t.Fatalf("result sources: got %v, want T.* plus U.X", keys)
}
if keys[0] != "..T.*" || keys[1] != "..U.X" {
t.Errorf("result sources: got %v, want [..T.* ..U.X]", keys)
}
}

func TestQuerySpan_PivotOverCTE(t *testing.T) {
span := mustExtract(t, "WITH c AS (SELECT a FROM real_table) SELECT * FROM c PIVOT(SUM(a) FOR m IN ('J')) AS p")

// The CTE's columns collapse to a whole-relation source under the CTE
// name (consistent with how non-pivoted CTE references are reported).
keys := resultSourceKeys(span, 0)
want := []string{"..C.*"}
if len(keys) != 1 || keys[0] != want[0] {
t.Fatalf("result sources: got %v, want %v", keys, want)
}
}
10 changes: 9 additions & 1 deletion snowflake/ast/parsenodes.go
Original file line number Diff line number Diff line change
Expand Up @@ -826,6 +826,8 @@ type StarRename struct {
// - Table: Name is set, others nil
// - Subquery: Subquery is set, Name is nil
// - Table function: FuncCall is set, Name is nil
// - Chained PIVOT/UNPIVOT: Nested is set, carrying the previous
// pivoted source (see Nested below)
// - Any of the above can have Lateral = true
type TableRef struct {
Name *ObjectName // table name; nil for subquery/func sources
Expand All @@ -834,7 +836,13 @@ type TableRef struct {
Subquery Node // (SELECT ...) or (VALUES ...) in FROM; nil for table refs
FuncCall *FuncCallExpr // TABLE(func(...)); nil for table refs
DollarN *DollarRef // $N result-set table ref (RESULT_SCAN of a prior result); nil otherwise
Lateral bool // LATERAL prefix
// Nested is the source relation of a chained PIVOT/UNPIVOT: in
// `src PIVOT(...) PIVOT(...)` each clause after the first applies to the
// result of the previous one, so the parser re-roots the TableRef with
// the prior pivoted ref here. When Nested is set, Name/Subquery/FuncCall/
// DollarN are nil and exactly one of Pivot/Unpivot is set.
Nested *TableRef
Lateral bool // LATERAL prefix
// Table-attached clauses (Snowflake). Any combination may appear; the
// documented source order is AT/BEFORE → CHANGES → MATCH_RECOGNIZE →
// PIVOT/UNPIVOT → alias → SAMPLE.
Expand Down
15 changes: 15 additions & 0 deletions snowflake/ast/walk_coverage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,21 @@ func TestWalkCoverage_FromClauses(t *testing.T) {
sql: "SELECT * FROM monthly_sales UNPIVOT (sales FOR month IN (jan, feb))",
tags: map[ast.NodeTag]int{ast.T_UnpivotColumn: 2},
},
{
name: "TableRef.Nested chained PIVOT reaches both clauses",
sql: `SELECT * FROM (SELECT * FROM quarterly_sales)
PIVOT (SUM(amount) FOR quarter_amount IN ('2023_Q1' AS q1, '2023_Q2'))
PIVOT (MAX(discount_percent) FOR quarter_discount IN ('2023_Q3'))`,
cols: []string{"AMOUNT", "QUARTER_AMOUNT", "DISCOUNT_PERCENT", "QUARTER_DISCOUNT"},
lits: []string{"2023_Q1", "2023_Q2", "2023_Q3"},
tags: map[ast.NodeTag]int{ast.T_PivotClause: 2, ast.T_PivotValue: 3},
},
{
name: "TableRef.Nested UNPIVOT then PIVOT",
sql: "SELECT * FROM t UNPIVOT (v FOR n IN (ca, cb)) PIVOT (SUM(v) FOR n IN ('x'))",
cols: []string{"V", "N"},
tags: map[ast.NodeTag]int{ast.T_UnpivotClause: 1, ast.T_PivotClause: 1, ast.T_UnpivotColumn: 2},
},
{
name: "MatchRecognizeClause OrderBy/Measures/Define",
sql: `SELECT * FROM stock_price_history MATCH_RECOGNIZE(
Expand Down
Loading
Loading