@@ -31,6 +31,9 @@ func SquashJoins(
31
31
defer span .Finish ()
32
32
33
33
a .Log ("squashing joins, node of type %T" , n )
34
+
35
+ projectSquashes := countProjectSquashes (n )
36
+
34
37
n , err := n .TransformUp (func (n sql.Node ) (sql.Node , error ) {
35
38
join , ok := n .(* plan.InnerJoin )
36
39
if ! ok {
@@ -39,18 +42,103 @@ func SquashJoins(
39
42
40
43
return squashJoin (join )
41
44
})
45
+
42
46
if err != nil {
43
47
return nil , err
44
48
}
45
49
46
- return n .TransformUp (func (n sql.Node ) (sql.Node , error ) {
50
+ n , err = n .TransformUp (func (n sql.Node ) (sql.Node , error ) {
47
51
t , ok := n .(* joinedTables )
48
52
if ! ok {
49
53
return n , nil
50
54
}
51
55
52
56
return buildSquashedTable (t .tables , t .filters , t .columns , t .indexes )
53
57
})
58
+
59
+ if err != nil {
60
+ return nil , err
61
+ }
62
+
63
+ return n .TransformUp (func (n sql.Node ) (sql.Node , error ) {
64
+ if projectSquashes <= 0 {
65
+ return n , nil
66
+ }
67
+
68
+ project , ok := n .(* plan.Project )
69
+ if ! ok {
70
+ return n , nil
71
+ }
72
+
73
+ child , ok := project .Child .(* plan.Project )
74
+ if ! ok {
75
+ return n , nil
76
+ }
77
+
78
+ squashedProject , err := squashProjects (project , child )
79
+ if err != nil {
80
+ return nil , err
81
+ }
82
+
83
+ projectSquashes --
84
+ return squashedProject , nil
85
+ })
86
+ }
87
+
88
+ func countProjectSquashes (n sql.Node ) int {
89
+ var squashableProjects int
90
+ plan .Inspect (n , func (node sql.Node ) bool {
91
+ if project , ok := node .(* plan.Project ); ok {
92
+ if _ , ok := project .Child .(* plan.InnerJoin ); ok {
93
+ squashableProjects ++
94
+ }
95
+ }
96
+
97
+ return true
98
+ })
99
+
100
+ return squashableProjects - 1
101
+ }
102
+
103
+ // ErrWrongProjection is raised if a plan.Project node contains a wrong expression.
104
+ var ErrWrongProjection = errors .NewKind ("wrong expression found in project node %s" )
105
+
106
+ func squashProjects (parent , child * plan.Project ) (sql.Node , error ) {
107
+ projections := []sql.Expression {}
108
+ for _ , expr := range parent .Expressions () {
109
+ parentField , ok := expr .(* expression.GetField )
110
+ if ! ok {
111
+ return nil , ErrWrongProjection .New (parent .String ())
112
+ }
113
+
114
+ index := parentField .Index ()
115
+ for _ , e := range child .Expressions () {
116
+ childField , ok := e .(* expression.GetField )
117
+ if ! ok {
118
+ return nil , ErrWrongProjection .New (child .String ())
119
+ }
120
+
121
+ if referenceSameColumn (parentField , childField ) {
122
+ index = childField .Index ()
123
+ }
124
+ }
125
+
126
+ projection := expression .NewGetFieldWithTable (
127
+ index ,
128
+ parentField .Type (),
129
+ parentField .Table (),
130
+ parentField .Name (),
131
+ parentField .IsNullable (),
132
+ )
133
+
134
+ projections = append (projections , projection )
135
+ }
136
+
137
+ return plan .NewProject (projections , child .Child ), nil
138
+ }
139
+
140
+ func referenceSameColumn (parent , child * expression.GetField ) bool {
141
+ return parent .Name () == child .Name () && parent .Table () == child .Table ()
54
142
}
55
143
56
144
func squashJoin (join * plan.InnerJoin ) (sql.Node , error ) {
0 commit comments