Skip to content

Commit 49a0a70

Browse files
authored
Merge pull request #515 from erizocosmico/fix/squash-unsquashable
rule: do not error when parts of joins are not squashable
2 parents e7bdc6e + f78aae4 commit 49a0a70

File tree

3 files changed

+428
-95
lines changed

3 files changed

+428
-95
lines changed

integration_test.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,15 @@ func TestSquashCorrectness(t *testing.T) {
413413
) t
414414
ORDER BY num_files DESC
415415
LIMIT 10`,
416+
417+
// Squash with non-squashable joins
418+
`SELECT * FROM refs NATURAL JOIN blobs`,
419+
`SELECT * FROM remotes NATURAL JOIN commits`,
420+
`SELECT *
421+
FROM repositories
422+
NATURAL JOIN refs
423+
NATURAL JOIN blobs
424+
NATURAL JOIN files`,
416425
}
417426

418427
for _, q := range queries {

internal/rule/squashjoins.go

Lines changed: 154 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
package rule
22

33
import (
4+
"fmt"
5+
"reflect"
6+
47
"github.com/src-d/gitbase"
58
errors "gopkg.in/src-d/go-errors.v1"
69
"gopkg.in/src-d/go-mysql-server.v0/sql"
@@ -50,7 +53,7 @@ func SquashJoins(
5053
return n, nil
5154
}
5255

53-
return buildSquashedTable(t.tables, t.filters, t.columns, t.indexes)
56+
return buildSquashedTable(a, t.tables, t.filters, t.columns, t.indexes)
5457
})
5558

5659
if err != nil {
@@ -214,7 +217,13 @@ func rearrange(join *plan.InnerJoin, squashedTable *joinedTables) sql.Node {
214217

215218
// errInvalidIteratorChain is the error kind for a table that cannot be chained
// onto the iterator built so far. Its message takes the table name (%s) and
// the offending iterator, whose dynamic type is printed (%T).
var errInvalidIteratorChain = errors.NewKind("invalid iterator to chain with %s: %T")
216219

220+
// unsquashableTable pairs a table that could not be chained into the squashed
// iterator with the filters that were split off for it at the moment it was
// set aside. buildSquashedTable later joins each of these tables back with the
// squashed node, using the filters as the join condition when there are any
// and a cross join otherwise.
type unsquashableTable struct {
221+
table sql.Table
222+
filters []sql.Expression
223+
}
224+
217225
func buildSquashedTable(
226+
a *analyzer.Analyzer,
218227
tables []sql.Table,
219228
filters []sql.Expression,
220229
columns []string,
@@ -229,9 +238,28 @@ func buildSquashedTable(
229238
index = idx
230239
}
231240

241+
tablesByName := make(map[string]sql.Table)
242+
for _, t := range tables {
243+
tablesByName[t.Name()] = t
244+
}
245+
246+
var usedTables []string
247+
var unsquashable []unsquashableTable
248+
addUnsquashable := func(tableName string) {
249+
var f []sql.Expression
250+
f, filters = filtersForTables(filters, usedTables...)
251+
unsquashable = append(unsquashable, unsquashableTable{
252+
table: tablesByName[tableName],
253+
filters: f,
254+
})
255+
}
256+
232257
var iter gitbase.ChainableIter
258+
var squashedTables []string
259+
233260
var err error
234261
for _, t := range tableNames {
262+
usedTables = append(usedTables, t)
235263
switch t {
236264
case gitbase.RepositoriesTableName:
237265
switch iter.(type) {
@@ -245,6 +273,7 @@ func buildSquashedTable(
245273
if err != nil {
246274
return nil, err
247275
}
276+
248277
iter = gitbase.NewAllReposIter(f)
249278
default:
250279
return nil, errInvalidIteratorChain.New("repositories", iter)
@@ -274,9 +303,11 @@ func buildSquashedTable(
274303
if err != nil {
275304
return nil, err
276305
}
306+
277307
iter = gitbase.NewAllRemotesIter(f)
278308
default:
279-
return nil, errInvalidIteratorChain.New("remotes", iter)
309+
addUnsquashable(gitbase.RemotesTableName)
310+
continue
280311
}
281312
case gitbase.ReferencesTableName:
282313
switch it := iter.(type) {
@@ -323,7 +354,8 @@ func buildSquashedTable(
323354
iter = gitbase.NewIndexRefsIter(f, index)
324355
}
325356
default:
326-
return nil, errInvalidIteratorChain.New("refs", iter)
357+
addUnsquashable(gitbase.ReferencesTableName)
358+
continue
327359
}
328360
case gitbase.RefCommitsTableName:
329361
switch it := iter.(type) {
@@ -375,7 +407,8 @@ func buildSquashedTable(
375407
iter = gitbase.NewAllRefCommitsIter(f)
376408
}
377409
default:
378-
return nil, errInvalidIteratorChain.New("ref_commits", iter)
410+
addUnsquashable(gitbase.RefCommitsTableName)
411+
continue
379412
}
380413
case gitbase.CommitsTableName:
381414
switch it := iter.(type) {
@@ -435,7 +468,8 @@ func buildSquashedTable(
435468
iter = gitbase.NewAllCommitsIter(f, false)
436469
}
437470
default:
438-
return nil, errInvalidIteratorChain.New("commits", iter)
471+
addUnsquashable(gitbase.CommitsTableName)
472+
continue
439473
}
440474
case gitbase.CommitTreesTableName:
441475
switch it := iter.(type) {
@@ -517,7 +551,8 @@ func buildSquashedTable(
517551
iter = gitbase.NewAllCommitTreesIter(f)
518552
}
519553
default:
520-
return nil, errInvalidIteratorChain.New("commit_trees", iter)
554+
addUnsquashable(gitbase.CommitTreesTableName)
555+
continue
521556
}
522557
case gitbase.CommitBlobsTableName:
523558
switch it := iter.(type) {
@@ -580,7 +615,8 @@ func buildSquashedTable(
580615
iter = gitbase.NewAllCommitBlobsIter(f)
581616
}
582617
default:
583-
return nil, errInvalidIteratorChain.New("commit_blobs", iter)
618+
addUnsquashable(gitbase.CommitBlobsTableName)
619+
continue
584620
}
585621
case gitbase.TreeEntriesTableName:
586622
switch it := iter.(type) {
@@ -641,7 +677,8 @@ func buildSquashedTable(
641677
iter = gitbase.NewAllTreeEntriesIter(f)
642678
}
643679
default:
644-
return nil, errInvalidIteratorChain.New("tree_entries", iter)
680+
addUnsquashable(gitbase.TreeEntriesTableName)
681+
continue
645682
}
646683
case gitbase.BlobsTableName:
647684
readContent := stringInSlice(columns, "blob_content")
@@ -687,7 +724,8 @@ func buildSquashedTable(
687724

688725
iter = gitbase.NewTreeEntryBlobsIter(it, f, readContent)
689726
default:
690-
return nil, errInvalidIteratorChain.New("blobs", iter)
727+
addUnsquashable(gitbase.BlobsTableName)
728+
continue
691729
}
692730
case gitbase.CommitFilesTableName:
693731
switch it := iter.(type) {
@@ -734,7 +772,8 @@ func buildSquashedTable(
734772
iter = gitbase.NewAllCommitFilesIter(f)
735773
}
736774
default:
737-
return nil, errInvalidIteratorChain.New("commit_files", iter)
775+
addUnsquashable(gitbase.CommitFilesTableName)
776+
continue
738777
}
739778
case gitbase.FilesTableName:
740779
readContent := stringInSlice(columns, "blob_content")
@@ -754,9 +793,12 @@ func buildSquashedTable(
754793

755794
iter = gitbase.NewCommitFileFilesIter(it, f, readContent)
756795
default:
757-
return nil, errInvalidIteratorChain.New("files", iter)
796+
addUnsquashable(gitbase.FilesTableName)
797+
continue
758798
}
759799
}
800+
801+
squashedTables = append(squashedTables, t)
760802
}
761803

762804
var originalSchema sql.Schema
@@ -771,17 +813,61 @@ func buildSquashedTable(
771813
indexedTables = []string{firstTable}
772814
}
773815

816+
var squashMapping = mapping
817+
if len(unsquashable) > 0 {
818+
squashMapping = nil
819+
}
820+
821+
var nonSquashedFilters []sql.Expression
822+
for _, t := range unsquashable {
823+
nonSquashedFilters = append(nonSquashedFilters, t.filters...)
824+
}
825+
nonSquashedFilters = append(nonSquashedFilters, filters...)
826+
squashedFilters := filterDiff(allFilters, nonSquashedFilters)
827+
774828
table := gitbase.NewSquashedTable(
775829
iter,
776-
mapping,
777-
allFilters,
830+
squashMapping,
831+
squashedFilters,
778832
indexedTables,
779-
tableNames...,
833+
squashedTables...,
780834
)
781835
var node sql.Node = plan.NewResolvedTable(table)
782836

837+
if len(unsquashable) > 0 {
838+
for _, t := range unsquashable {
839+
var table sql.Node = plan.NewResolvedTable(t.table)
840+
if a.Parallelism > 1 {
841+
table = plan.NewExchange(a.Parallelism, table)
842+
}
843+
844+
if len(t.filters) > 0 {
845+
f, err := fixFieldIndexes(
846+
expression.JoinAnd(t.filters...),
847+
append(t.table.Schema(), node.Schema()...),
848+
)
849+
if err != nil {
850+
return nil, err
851+
}
852+
853+
node = plan.NewInnerJoin(
854+
table,
855+
node,
856+
f,
857+
)
858+
} else {
859+
node = plan.NewCrossJoin(table, node)
860+
}
861+
}
862+
863+
node, err = projectSchema(node, originalSchema)
864+
if err != nil {
865+
return nil, err
866+
}
867+
}
868+
783869
if len(filters) > 0 {
784-
f, err := fixFieldIndexes(expression.JoinAnd(filters...), iter.Schema())
870+
f, err := fixFieldIndexes(expression.JoinAnd(filters...), originalSchema)
785871
if err != nil {
786872
return nil, err
787873
}
@@ -791,6 +877,39 @@ func buildSquashedTable(
791877
return node, nil
792878
}
793879

880+
// errUnsquashableFieldNotFound is the error kind returned by projectSchema when
// a column of the requested schema is missing from the node being projected.
// Its message takes the column source (table) and the column name.
var errUnsquashableFieldNotFound = errors.NewKind("unable to unsquash table, column %s.%s not found")
881+
882+
// projectSchema wraps the node in a Project node that has the same schema as
883+
// the one provided.
884+
func projectSchema(node sql.Node, schema sql.Schema) (sql.Node, error) {
885+
if node.Schema().Equals(schema) {
886+
return node, nil
887+
}
888+
889+
var columnIndexes = make(map[string]int)
890+
for i, col := range node.Schema() {
891+
columnIndexes[fmt.Sprintf("%s.%s", col.Source, col.Name)] = i + 1
892+
}
893+
894+
var project = make([]sql.Expression, len(schema))
895+
for i, col := range schema {
896+
idx := columnIndexes[fmt.Sprintf("%s.%s", col.Source, col.Name)]
897+
if idx <= 0 {
898+
return nil, errUnsquashableFieldNotFound.New(col.Source, col.Name)
899+
}
900+
901+
project[i] = expression.NewGetFieldWithTable(
902+
idx-1,
903+
col.Type,
904+
col.Source,
905+
col.Name,
906+
col.Nullable,
907+
)
908+
}
909+
910+
return plan.NewProject(project, node), nil
911+
}
912+
794913
// buildSchemaMapping returns a mapping to convert the actual schema into the
795914
// original schema. If both schemas are equal, nil will be returned.
796915
func buildSchemaMapping(original, actual sql.Schema) []int {
@@ -1409,3 +1528,23 @@ func fixFieldIndexes(e sql.Expression, schema sql.Schema) (sql.Expression, error
14091528
return nil, analyzer.ErrColumnTableNotFound.New(gf.Table(), gf.Name())
14101529
})
14111530
}
1531+
1532+
func filterDiff(a, b []sql.Expression) []sql.Expression {
1533+
var result []sql.Expression
1534+
1535+
for _, fa := range a {
1536+
var found bool
1537+
for _, fb := range b {
1538+
if reflect.DeepEqual(fa, fb) {
1539+
found = true
1540+
break
1541+
}
1542+
}
1543+
1544+
if !found {
1545+
result = append(result, fa)
1546+
}
1547+
}
1548+
1549+
return result
1550+
}

0 commit comments

Comments
 (0)