Skip to content

Commit fda2ff7

Browse files
committed
internal/rule: fix squash joins when there are exchanges
Signed-off-by: Miguel Molina <[email protected]>
1 parent a0ad599 commit fda2ff7

File tree

2 files changed

+39
-1
lines changed

2 files changed

+39
-1
lines changed

internal/rule/squashjoins.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,17 +182,21 @@ func joinTables(join *plan.InnerJoin) (*joinedTables, error) {
182182
func rearrange(join *plan.InnerJoin, squashedTable *joinedTables) sql.Node {
183183
var projections []sql.Expression
184184
var filters []sql.Expression
185+
var parallelism int
185186
plan.Inspect(join, func(node sql.Node) bool {
186187
switch node := node.(type) {
187188
case *plan.Project:
188189
projections = append(projections, node.Projections...)
189190
case *plan.Filter:
190191
filters = append(filters, node.Expression)
192+
case *plan.Exchange:
193+
parallelism = node.Parallelism
191194
}
192195
return true
193196
})
194197

195198
var node sql.Node = squashedTable
199+
196200
if len(filters) > 0 {
197201
node = plan.NewFilter(expression.JoinAnd(filters...), node)
198202
}
@@ -201,6 +205,10 @@ func rearrange(join *plan.InnerJoin, squashedTable *joinedTables) sql.Node {
201205
node = plan.NewProject(projections, node)
202206
}
203207

208+
if parallelism > 1 {
209+
node = plan.NewExchange(parallelism, node)
210+
}
211+
204212
return node
205213
}
206214

@@ -865,7 +873,7 @@ func isJoinLeafSquashable(node sql.Node) bool {
865873
return false
866874
}
867875
hasSquashableTables = true
868-
case *plan.Project, *plan.Filter, *plan.TableAlias, nil:
876+
case *plan.Project, *plan.Filter, *plan.TableAlias, *plan.Exchange, nil:
869877
default:
870878
hasUnsquashableNodes = true
871879
return false

internal/rule/squashjoins_test.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,39 @@ import (
1010
"gopkg.in/src-d/go-mysql-server.v0/sql"
1111
"gopkg.in/src-d/go-mysql-server.v0/sql/analyzer"
1212
"gopkg.in/src-d/go-mysql-server.v0/sql/expression"
13+
"gopkg.in/src-d/go-mysql-server.v0/sql/parse"
1314
"gopkg.in/src-d/go-mysql-server.v0/sql/plan"
1415
)
1516

17+
func TestAnalyzeSquashJoinsExchange(t *testing.T) {
18+
require := require.New(t)
19+
20+
catalog := sql.NewCatalog()
21+
catalog.AddDatabase(gitbase.NewDatabase("foo"))
22+
a := analyzer.NewBuilder(catalog).
23+
WithParallelism(2).
24+
AddPostAnalyzeRule(SquashJoinsRule, SquashJoins).
25+
Build()
26+
a.CurrentDatabase = "foo"
27+
ctx := sql.NewEmptyContext()
28+
29+
node, err := parse.Parse(ctx, `SELECT * FROM ref_commits NATURAL JOIN commits`)
30+
require.NoError(err)
31+
32+
result, err := a.Analyze(ctx, node)
33+
require.NoError(err)
34+
35+
project, ok := result.(*plan.Project)
36+
require.True(ok)
37+
38+
exchange, ok := project.Child.(*plan.Exchange)
39+
require.True(ok)
40+
require.Equal(2, exchange.Parallelism)
41+
42+
_, ok = exchange.Child.(*gitbase.SquashedTable)
43+
require.True(ok)
44+
}
45+
1646
func TestSquashJoins(t *testing.T) {
1747
require := require.New(t)
1848

0 commit comments

Comments
 (0)