Skip to content
This repository was archived by the owner on Jan 28, 2021. It is now read-only.

Commit d42e83a

Browse files
erizocosmicoajnavarro
authored andcommitted
sql/analyzer: fix ambiguous table after natural join
Signed-off-by: Miguel Molina <[email protected]> (cherry picked from commit ac92402)
1 parent 27587a5 commit d42e83a

File tree

3 files changed

+117
-23
lines changed

3 files changed

+117
-23
lines changed

engine_test.go

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -160,29 +160,29 @@ var queries = []struct {
160160
},
161161
},
162162
{
163-
"SELECT text > 2 FROM tabletest",
163+
"SELECT s > 2 FROM tabletest",
164164
[]sql.Row{
165165
{false},
166166
{false},
167167
{false},
168168
},
169169
},
170170
{
171-
"SELECT * FROM tabletest WHERE text > 0",
171+
"SELECT * FROM tabletest WHERE s > 0",
172172
nil,
173173
},
174174
{
175-
"SELECT * FROM tabletest WHERE text = 0",
175+
"SELECT * FROM tabletest WHERE s = 0",
176176
[]sql.Row{
177-
{"a", int32(1)},
178-
{"b", int32(2)},
179-
{"c", int32(3)},
177+
{int64(1), "first row"},
178+
{int64(2), "second row"},
179+
{int64(3), "third row"},
180180
},
181181
},
182182
{
183-
"SELECT * FROM tabletest WHERE text = 'a'",
183+
"SELECT * FROM tabletest WHERE s = 'first row'",
184184
[]sql.Row{
185-
{"a", int32(1)},
185+
{int64(1), "first row"},
186186
},
187187
},
188188
{
@@ -211,15 +211,15 @@ var queries = []struct {
211211
{
212212
`SELECT * FROM tabletest, mytable mt INNER JOIN othertable ot ON mt.i = ot.i2`,
213213
[]sql.Row{
214-
{"a", int32(1), int64(1), "first row", "third", int64(1)},
215-
{"a", int32(1), int64(2), "second row", "second", int64(2)},
216-
{"a", int32(1), int64(3), "third row", "first", int64(3)},
217-
{"b", int32(2), int64(1), "first row", "third", int64(1)},
218-
{"b", int32(2), int64(2), "second row", "second", int64(2)},
219-
{"b", int32(2), int64(3), "third row", "first", int64(3)},
220-
{"c", int32(3), int64(1), "first row", "third", int64(1)},
221-
{"c", int32(3), int64(2), "second row", "second", int64(2)},
222-
{"c", int32(3), int64(3), "third row", "first", int64(3)},
214+
{int64(1), "first row", int64(1), "first row", "third", int64(1)},
215+
{int64(1), "first row", int64(2), "second row", "second", int64(2)},
216+
{int64(1), "first row", int64(3), "third row", "first", int64(3)},
217+
{int64(2), "second row", int64(1), "first row", "third", int64(1)},
218+
{int64(2), "second row", int64(2), "second row", "second", int64(2)},
219+
{int64(2), "second row", int64(3), "third row", "first", int64(3)},
220+
{int64(3), "third row", int64(1), "first row", "third", int64(1)},
221+
{int64(3), "third row", int64(2), "second row", "second", int64(2)},
222+
{int64(3), "third row", int64(3), "third row", "first", int64(3)},
223223
},
224224
},
225225
{
@@ -741,15 +741,15 @@ func newEngineWithParallelism(t *testing.T, parallelism int) *sqle.Engine {
741741
)
742742

743743
table3 := mem.NewPartitionedTable("tabletest", sql.Schema{
744-
{Name: "text", Type: sql.Text, Source: "tabletest"},
745-
{Name: "number", Type: sql.Int32, Source: "tabletest"},
744+
{Name: "i", Type: sql.Int32, Source: "tabletest"},
745+
{Name: "s", Type: sql.Text, Source: "tabletest"},
746746
}, testNumPartitions)
747747

748748
insertRows(
749749
t, table3,
750-
sql.NewRow("a", int32(1)),
751-
sql.NewRow("b", int32(2)),
752-
sql.NewRow("c", int32(3)),
750+
sql.NewRow(int64(1), "first row"),
751+
sql.NewRow(int64(2), "second row"),
752+
sql.NewRow(int64(3), "third row"),
753753
)
754754

755755
db := mem.NewDatabase("mydb")

sql/analyzer/resolve_columns.go

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,14 @@ func qualifyColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error)
5050
}
5151
}
5252

53+
var projects, seenProjects int
54+
plan.Inspect(n, func(n sql.Node) bool {
55+
if _, ok := n.(*plan.Project); ok {
56+
projects++
57+
}
58+
return true
59+
})
60+
5361
return n.TransformUp(func(n sql.Node) (sql.Node, error) {
5462
a.Log("transforming node of type: %T", n)
5563
switch n := n.(type) {
@@ -68,7 +76,7 @@ func qualifyColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error)
6876
indexCols(name, n.Schema())
6977
}
7078

71-
return n.TransformExpressionsUp(func(e sql.Expression) (sql.Expression, error) {
79+
result, err := n.TransformExpressionsUp(func(e sql.Expression) (sql.Expression, error) {
7280
a.Log("transforming expression of type: %T", e)
7381
switch col := e.(type) {
7482
case *expression.UnresolvedColumn:
@@ -135,6 +143,51 @@ func qualifyColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error)
135143

136144
return e, nil
137145
})
146+
147+
if err != nil {
148+
return nil, err
149+
}
150+
151+
// We should ignore the topmost project, because some nodes are
152+
// reordered, such as Sort, and they would not be resolved well.
153+
if n, ok := result.(*plan.Project); ok && projects-seenProjects > 1 {
154+
seenProjects++
155+
156+
// We need to modify the indexed columns to only contain what is
157+
// projected in this project. If the column is not qualified by any
158+
// table, just keep the ones that are currently in the index.
159+
// If it is, then just make those tables available for the column.
160+
// If we don't do this, columns that are not projected will be
161+
// available in this step and may cause false errors or unintended
162+
// results.
163+
var projected = make(map[string][]string)
164+
for _, p := range n.Projections {
165+
var table, col string
166+
switch p := p.(type) {
167+
case column:
168+
table = p.Table()
169+
col = p.Name()
170+
case *expression.GetField:
171+
table = p.Table()
172+
col = p.Name()
173+
default:
174+
continue
175+
}
176+
177+
if table != "" {
178+
projected[col] = append(projected[col], table)
179+
} else {
180+
projected[col] = append(projected[col], colIndex[col]...)
181+
}
182+
}
183+
184+
colIndex = make(map[string][]string)
185+
for col, tables := range projected {
186+
colIndex[col] = dedupStrings(tables)
187+
}
188+
}
189+
190+
return result, nil
138191
})
139192
}
140193

sql/analyzer/resolve_columns_test.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,47 @@ import (
1111
"gopkg.in/src-d/go-mysql-server.v0/sql/plan"
1212
)
1313

14+
func TestQualifyColumnsProject(t *testing.T) {
15+
require := require.New(t)
16+
17+
table := mem.NewTable("foo", sql.Schema{
18+
{Name: "a", Type: sql.Text, Source: "foo"},
19+
{Name: "b", Type: sql.Text, Source: "foo"},
20+
})
21+
22+
node := plan.NewProject(
23+
[]sql.Expression{
24+
expression.NewUnresolvedColumn("a"),
25+
expression.NewUnresolvedColumn("b"),
26+
},
27+
plan.NewProject(
28+
[]sql.Expression{
29+
expression.NewUnresolvedQualifiedColumn("foo", "a"),
30+
},
31+
plan.NewResolvedTable(table),
32+
),
33+
)
34+
35+
result, err := qualifyColumns(sql.NewEmptyContext(), NewDefault(nil), node)
36+
require.NoError(err)
37+
38+
expected := plan.NewProject(
39+
[]sql.Expression{
40+
expression.NewUnresolvedQualifiedColumn("foo", "a"),
41+
// b is not qualified because it's not projected
42+
expression.NewUnresolvedColumn("b"),
43+
},
44+
plan.NewProject(
45+
[]sql.Expression{
46+
expression.NewUnresolvedQualifiedColumn("foo", "a"),
47+
},
48+
plan.NewResolvedTable(table),
49+
),
50+
)
51+
52+
require.Equal(expected, result)
53+
}
54+
1455
func TestMisusedAlias(t *testing.T) {
1556
require := require.New(t)
1657
f := getRule("resolve_columns")

0 commit comments

Comments
 (0)