Skip to content

Commit 52acd0e

Browse files
committed
rule: fix inability to check conditions of natural joins
Signed-off-by: Miguel Molina <[email protected]>
1 parent cf34387 commit 52acd0e

File tree

2 files changed

+116
-1
lines changed

2 files changed

+116
-1
lines changed

internal/rule/squashjoins.go

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1009,10 +1009,14 @@ func isJoinCondSquashable(join *plan.InnerJoin) bool {
10091009
leftTables := findLeafTables(join.Left)
10101010
rightTables := findLeafTables(join.Right)
10111011

1012+
var squashedTables []string
10121013
if len(rightTables) == 1 {
1014+
squashedTables = findSquashedTables(join.Left)
10131015
leftTables, rightTables = rightTables, leftTables
10141016
} else if len(leftTables) != 1 {
10151017
return false
1018+
} else {
1019+
squashedTables = findSquashedTables(join.Right)
10161020
}
10171021

10181022
lt := leftTables[0]
@@ -1021,8 +1025,33 @@ func isJoinCondSquashable(join *plan.InnerJoin) bool {
10211025
continue
10221026
}
10231027

1028+
var cond = join.Cond
1029+
// if the right table is squashed, we might need to rewrite some column
1030+
// tables in order to find the condition, since natural joins deduplicate
1031+
// columns with the same name.
1032+
if stringInSlice(squashedTables, rt) {
1033+
c, err := join.Cond.TransformUp(func(e sql.Expression) (sql.Expression, error) {
1034+
gf, ok := e.(*expression.GetField)
1035+
if ok && gf.Table() != rt && gf.Table() != lt {
1036+
if tableHasColumn(rt, gf.Name()) {
1037+
return expression.NewGetFieldWithTable(
1038+
gf.Index(),
1039+
gf.Type(),
1040+
rt,
1041+
gf.Name(),
1042+
gf.IsNullable(),
1043+
), nil
1044+
}
1045+
}
1046+
return e, nil
1047+
})
1048+
if err == nil {
1049+
cond = c
1050+
}
1051+
}
1052+
10241053
t1, t2 := orderedTablePair(lt, rt)
1025-
if hasChainableJoinCondition(join.Cond, t1, t2) {
1054+
if hasChainableJoinCondition(cond, t1, t2) {
10261055
return true
10271056
}
10281057
}
@@ -1064,6 +1093,20 @@ func findLeafTables(n sql.Node) []string {
10641093
return tables
10651094
}
10661095

1096+
func findSquashedTables(n sql.Node) []string {
1097+
var tables []string
1098+
plan.Inspect(n, func(n sql.Node) bool {
1099+
switch n := n.(type) {
1100+
case *joinedTables:
1101+
tables = orderedTableNames(n.tables)
1102+
return false
1103+
default:
1104+
return true
1105+
}
1106+
})
1107+
return tables
1108+
}
1109+
10671110
func exprToFilters(expr sql.Expression) (filters []sql.Expression) {
10681111
if expr, ok := expr.(*expression.And); ok {
10691112
return append(exprToFilters(expr.Left), exprToFilters(expr.Right)...)
@@ -1548,3 +1591,36 @@ func filterDiff(a, b []sql.Expression) []sql.Expression {
15481591

15491592
return result
15501593
}
1594+
1595+
func tableHasColumn(t, col string) bool {
1596+
return tableSchema(t).Contains(col, t)
1597+
}
1598+
1599+
func tableSchema(t string) sql.Schema {
1600+
switch t {
1601+
case gitbase.RepositoriesTableName:
1602+
return gitbase.RepositoriesSchema
1603+
case gitbase.ReferencesTableName:
1604+
return gitbase.RefsSchema
1605+
case gitbase.RemotesTableName:
1606+
return gitbase.RemotesSchema
1607+
case gitbase.RefCommitsTableName:
1608+
return gitbase.RefCommitsSchema
1609+
case gitbase.CommitsTableName:
1610+
return gitbase.CommitsSchema
1611+
case gitbase.CommitTreesTableName:
1612+
return gitbase.CommitTreesSchema
1613+
case gitbase.CommitBlobsTableName:
1614+
return gitbase.CommitBlobsSchema
1615+
case gitbase.CommitFilesTableName:
1616+
return gitbase.CommitFilesSchema
1617+
case gitbase.TreeEntriesTableName:
1618+
return gitbase.TreeEntriesSchema
1619+
case gitbase.BlobsTableName:
1620+
return gitbase.BlobsSchema
1621+
case gitbase.FilesTableName:
1622+
return gitbase.FilesSchema
1623+
default:
1624+
return nil
1625+
}
1626+
}

internal/rule/squashjoins_test.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,45 @@ func TestAnalyzeSquashJoinsExchange(t *testing.T) {
4747
require.True(ok)
4848
}
4949

50+
func TestAnalyzeSquashNaturalJoins(t *testing.T) {
51+
require := require.New(t)
52+
53+
catalog := sql.NewCatalog()
54+
catalog.AddDatabase(gitbase.NewDatabase("foo"))
55+
a := analyzer.NewBuilder(catalog).
56+
WithParallelism(2).
57+
AddPostAnalyzeRule(SquashJoinsRule, SquashJoins).
58+
Build()
59+
a.Batches[len(a.Batches)-1].Rules = a.Batches[len(a.Batches)-1].Rules[1:]
60+
a.CurrentDatabase = "foo"
61+
ctx := sql.NewEmptyContext()
62+
63+
node, err := parse.Parse(ctx, `SELECT * FROM refs
64+
NATURAL JOIN commits
65+
NATURAL JOIN commit_files
66+
NATURAL JOIN files`)
67+
require.NoError(err)
68+
69+
result, err := a.Analyze(ctx, node)
70+
require.NoError(err)
71+
72+
exchange, ok := result.(*plan.Exchange)
73+
require.True(ok)
74+
require.Equal(2, exchange.Parallelism)
75+
76+
project, ok := exchange.Child.(*plan.Project)
77+
require.True(ok)
78+
79+
filter, ok := project.Child.(*plan.Filter)
80+
require.True(ok)
81+
82+
rt, ok := filter.Child.(*plan.ResolvedTable)
83+
require.True(ok)
84+
85+
_, ok = rt.Table.(*gitbase.SquashedTable)
86+
require.True(ok)
87+
}
88+
5089
func TestSquashJoins(t *testing.T) {
5190
require := require.New(t)
5291

0 commit comments

Comments
 (0)