@@ -1009,10 +1009,14 @@ func isJoinCondSquashable(join *plan.InnerJoin) bool {
1009
1009
leftTables := findLeafTables (join .Left )
1010
1010
rightTables := findLeafTables (join .Right )
1011
1011
1012
+ var squashedTables []string
1012
1013
if len (rightTables ) == 1 {
1014
+ squashedTables = findSquashedTables (join .Left )
1013
1015
leftTables , rightTables = rightTables , leftTables
1014
1016
} else if len (leftTables ) != 1 {
1015
1017
return false
1018
+ } else {
1019
+ squashedTables = findSquashedTables (join .Right )
1016
1020
}
1017
1021
1018
1022
lt := leftTables [0 ]
@@ -1021,8 +1025,33 @@ func isJoinCondSquashable(join *plan.InnerJoin) bool {
1021
1025
continue
1022
1026
}
1023
1027
1028
+ var cond = join .Cond
1029
+ // if the right table is squashed, we might need to rewrite some column
1030
+ // tables in order to find the condition, since natural joins deduplicate
1031
+ // columns with the same name.
1032
+ if stringInSlice (squashedTables , rt ) {
1033
+ c , err := join .Cond .TransformUp (func (e sql.Expression ) (sql.Expression , error ) {
1034
+ gf , ok := e .(* expression.GetField )
1035
+ if ok && gf .Table () != rt && gf .Table () != lt {
1036
+ if tableHasColumn (rt , gf .Name ()) {
1037
+ return expression .NewGetFieldWithTable (
1038
+ gf .Index (),
1039
+ gf .Type (),
1040
+ rt ,
1041
+ gf .Name (),
1042
+ gf .IsNullable (),
1043
+ ), nil
1044
+ }
1045
+ }
1046
+ return e , nil
1047
+ })
1048
+ if err == nil {
1049
+ cond = c
1050
+ }
1051
+ }
1052
+
1024
1053
t1 , t2 := orderedTablePair (lt , rt )
1025
- if hasChainableJoinCondition (join . Cond , t1 , t2 ) {
1054
+ if hasChainableJoinCondition (cond , t1 , t2 ) {
1026
1055
return true
1027
1056
}
1028
1057
}
@@ -1064,6 +1093,20 @@ func findLeafTables(n sql.Node) []string {
1064
1093
return tables
1065
1094
}
1066
1095
1096
+ func findSquashedTables (n sql.Node ) []string {
1097
+ var tables []string
1098
+ plan .Inspect (n , func (n sql.Node ) bool {
1099
+ switch n := n .(type ) {
1100
+ case * joinedTables :
1101
+ tables = orderedTableNames (n .tables )
1102
+ return false
1103
+ default :
1104
+ return true
1105
+ }
1106
+ })
1107
+ return tables
1108
+ }
1109
+
1067
1110
func exprToFilters (expr sql.Expression ) (filters []sql.Expression ) {
1068
1111
if expr , ok := expr .(* expression.And ); ok {
1069
1112
return append (exprToFilters (expr .Left ), exprToFilters (expr .Right )... )
@@ -1548,3 +1591,36 @@ func filterDiff(a, b []sql.Expression) []sql.Expression {
1548
1591
1549
1592
return result
1550
1593
}
1594
+
1595
+ func tableHasColumn (t , col string ) bool {
1596
+ return tableSchema (t ).Contains (col , t )
1597
+ }
1598
+
1599
+ func tableSchema (t string ) sql.Schema {
1600
+ switch t {
1601
+ case gitbase .RepositoriesTableName :
1602
+ return gitbase .RepositoriesSchema
1603
+ case gitbase .ReferencesTableName :
1604
+ return gitbase .RefsSchema
1605
+ case gitbase .RemotesTableName :
1606
+ return gitbase .RemotesSchema
1607
+ case gitbase .RefCommitsTableName :
1608
+ return gitbase .RefCommitsSchema
1609
+ case gitbase .CommitsTableName :
1610
+ return gitbase .CommitsSchema
1611
+ case gitbase .CommitTreesTableName :
1612
+ return gitbase .CommitTreesSchema
1613
+ case gitbase .CommitBlobsTableName :
1614
+ return gitbase .CommitBlobsSchema
1615
+ case gitbase .CommitFilesTableName :
1616
+ return gitbase .CommitFilesSchema
1617
+ case gitbase .TreeEntriesTableName :
1618
+ return gitbase .TreeEntriesSchema
1619
+ case gitbase .BlobsTableName :
1620
+ return gitbase .BlobsSchema
1621
+ case gitbase .FilesTableName :
1622
+ return gitbase .FilesSchema
1623
+ default :
1624
+ return nil
1625
+ }
1626
+ }
0 commit comments