Skip to content

Commit d1560a8

Browse files
authored
add analyzer package. (#26)
* Add analyzer.Analyzer. * Add rule to resolve table.
1 parent 9fd1090 commit d1560a8

File tree

10 files changed

+319
-4
lines changed

10 files changed

+319
-4
lines changed

git/commits.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ func (commitsRelation) Schema() sql.Schema {
3838
}
3939

4040
func (r *commitsRelation) TransformUp(f func(sql.Node) sql.Node) sql.Node {
41-
return f(newCommitsRelation(r.r))
41+
return f(r)
4242
}
4343

4444
func (r *commitsRelation) TransformExpressionsUp(f func(sql.Expression) sql.Expression) sql.Node {

mem/table.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ func (t *Table) RowIter() (sql.RowIter, error) {
4242
}
4343

4444
func (t *Table) TransformUp(f func(sql.Node) sql.Node) sql.Node {
45-
return f(NewTable(t.name, t.schema))
45+
return f(t)
4646
}
4747

4848
func (t *Table) TransformExpressionsUp(f func(sql.Expression) sql.Expression) sql.Node {

sql/analyzer/analyzer.go

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
package analyzer
2+
3+
import (
4+
"errors"
5+
"fmt"
6+
"reflect"
7+
8+
"github.com/mvader/gitql/sql"
9+
)
10+
11+
const maxAnalysisIterations = 1000
12+
13+
type Analyzer struct {
14+
Rules []Rule
15+
Catalog sql.Catalog
16+
CurrentDatabase string
17+
}
18+
19+
type Rule struct {
20+
Name string
21+
Apply func(*Analyzer, sql.Node) sql.Node
22+
}
23+
24+
func New(catalog sql.Catalog) *Analyzer {
25+
return &Analyzer{
26+
Rules: DefaultRules,
27+
Catalog: catalog,
28+
}
29+
}
30+
31+
func (a *Analyzer) Analyze(n sql.Node) (sql.Node, error) {
32+
prev := n
33+
cur := a.analyzeOnce(n)
34+
i := 0
35+
for !reflect.DeepEqual(prev, cur) {
36+
prev = cur
37+
cur = a.analyzeOnce(n)
38+
i += 1
39+
if i >= maxAnalysisIterations {
40+
return cur, fmt.Errorf("exceeded max analysis iterations (%d)", maxAnalysisIterations)
41+
}
42+
}
43+
44+
err := a.validate(cur)
45+
if err != nil {
46+
return cur, err
47+
}
48+
49+
return cur, nil
50+
}
51+
52+
func (a *Analyzer) analyzeOnce(n sql.Node) sql.Node {
53+
result := n
54+
for _, rule := range a.Rules {
55+
result = rule.Apply(a, result)
56+
}
57+
return result
58+
}
59+
60+
func (a *Analyzer) validate(n sql.Node) error {
61+
if !n.Resolved() {
62+
return errors.New("plan is not resolved")
63+
}
64+
65+
return nil
66+
}

sql/analyzer/analyzer_test.go

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
package analyzer_test
2+
3+
import (
4+
"fmt"
5+
"testing"
6+
7+
"github.com/mvader/gitql/mem"
8+
"github.com/mvader/gitql/sql"
9+
"github.com/mvader/gitql/sql/analyzer"
10+
"github.com/mvader/gitql/sql/expression"
11+
"github.com/mvader/gitql/sql/plan"
12+
"github.com/stretchr/testify/require"
13+
)
14+
15+
func TestAnalyzer_Analyze(t *testing.T) {
16+
assert := require.New(t)
17+
18+
table := mem.NewTable("mytable", sql.Schema{{"i", sql.Integer}})
19+
db := mem.NewDatabase("mydb")
20+
db.AddTable("mytable", table)
21+
22+
catalog := sql.Catalog{db}
23+
a := analyzer.New(catalog)
24+
a.CurrentDatabase = "mydb"
25+
26+
var notAnalyzed sql.Node = plan.NewUnresolvedRelation("mytable")
27+
analyzed, err := a.Analyze(notAnalyzed)
28+
assert.Nil(err)
29+
assert.Equal(table, analyzed)
30+
31+
notAnalyzed = plan.NewUnresolvedRelation("nonexistant")
32+
analyzed, err = a.Analyze(notAnalyzed)
33+
assert.NotNil(err)
34+
assert.Equal(notAnalyzed, analyzed)
35+
36+
analyzed, err = a.Analyze(table)
37+
assert.Nil(err)
38+
assert.Equal(table, analyzed)
39+
40+
notAnalyzed = plan.NewProject(
41+
[]sql.Expression{expression.NewUnresolvedColumn("i")},
42+
plan.NewUnresolvedRelation("mytable"),
43+
)
44+
analyzed, err = a.Analyze(notAnalyzed)
45+
expected := plan.NewProject(
46+
[]sql.Expression{expression.NewGetField(0, sql.Integer, "i")},
47+
table,
48+
)
49+
assert.Nil(err)
50+
assert.Equal(expected, analyzed)
51+
}
52+
53+
func TestAnalyzer_Analyze_MaxIterations(t *testing.T) {
54+
assert := require.New(t)
55+
56+
catalog := sql.Catalog{}
57+
a := analyzer.New(catalog)
58+
a.CurrentDatabase = "mydb"
59+
60+
i := 0
61+
a.Rules = []analyzer.Rule{{
62+
"infinite",
63+
func(a *analyzer.Analyzer, n sql.Node) sql.Node {
64+
i += 1
65+
return plan.NewUnresolvedRelation(fmt.Sprintf("rel%d", i))
66+
},
67+
}}
68+
69+
notAnalyzed := plan.NewUnresolvedRelation("mytable")
70+
analyzed, err := a.Analyze(notAnalyzed)
71+
assert.NotNil(err)
72+
assert.Equal(plan.NewUnresolvedRelation("rel1001"), analyzed)
73+
}

sql/analyzer/rules.go

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
package analyzer
2+
3+
import (
4+
"fmt"
5+
"github.com/mvader/gitql/sql"
6+
"github.com/mvader/gitql/sql/expression"
7+
"github.com/mvader/gitql/sql/plan"
8+
)
9+
10+
var DefaultRules = []Rule{
11+
{"resolve_tables", resolveTables},
12+
{"resolve_columns", resolveColumns},
13+
}
14+
15+
func resolveTables(a *Analyzer, n sql.Node) sql.Node {
16+
return n.TransformUp(func(n sql.Node) sql.Node {
17+
t, ok := n.(*plan.UnresolvedRelation)
18+
if !ok {
19+
return n
20+
}
21+
22+
rt, err := a.Catalog.Table(a.CurrentDatabase, t.Name)
23+
if err != nil {
24+
return n
25+
}
26+
27+
return rt
28+
})
29+
}
30+
31+
func resolveColumns(a *Analyzer, n sql.Node) sql.Node {
32+
if n.Resolved() {
33+
return n
34+
}
35+
36+
if len(n.Children()) != 1 {
37+
return n
38+
}
39+
40+
child := n.Children()[0]
41+
if !child.Resolved() {
42+
return n
43+
}
44+
45+
colMap := map[string]*expression.GetField{}
46+
for idx, child := range child.Schema() {
47+
colMap[child.Name] = expression.NewGetField(idx, child.Type, child.Name)
48+
}
49+
50+
return n.TransformExpressionsUp(func(e sql.Expression) sql.Expression {
51+
uc, ok := e.(*expression.UnresolvedColumn)
52+
if !ok {
53+
return e
54+
}
55+
56+
gf, ok := colMap[uc.Name()]
57+
if !ok {
58+
return e
59+
}
60+
61+
return gf
62+
})
63+
}

sql/analyzer/rules_test.go

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
package analyzer_test
2+
3+
import (
4+
"testing"
5+
6+
"github.com/mvader/gitql/mem"
7+
"github.com/mvader/gitql/sql"
8+
"github.com/mvader/gitql/sql/analyzer"
9+
"github.com/mvader/gitql/sql/expression"
10+
"github.com/mvader/gitql/sql/plan"
11+
"github.com/stretchr/testify/assert"
12+
)
13+
14+
func Test_resolveTables(t *testing.T) {
15+
assert := assert.New(t)
16+
17+
f := getRule("resolve_tables")
18+
19+
table := mem.NewTable("mytable", sql.Schema{{"i", sql.Integer}})
20+
db := mem.NewDatabase("mydb")
21+
db.AddTable("mytable", table)
22+
23+
catalog := sql.Catalog{db}
24+
25+
a := analyzer.New(catalog)
26+
a.Rules = []analyzer.Rule{f}
27+
28+
a.CurrentDatabase = "mydb"
29+
var notAnalyzed sql.Node = plan.NewUnresolvedRelation("mytable")
30+
analyzed := f.Apply(a, notAnalyzed)
31+
assert.Equal(table, analyzed)
32+
33+
notAnalyzed = plan.NewUnresolvedRelation("nonexistant")
34+
analyzed = f.Apply(a, notAnalyzed)
35+
assert.Equal(notAnalyzed, analyzed)
36+
37+
analyzed = f.Apply(a, table)
38+
assert.Equal(table, analyzed)
39+
40+
}
41+
42+
func Test_resolveTables_Nested(t *testing.T) {
43+
assert := assert.New(t)
44+
45+
f := getRule("resolve_tables")
46+
47+
table := mem.NewTable("mytable", sql.Schema{{"i", sql.Integer}})
48+
db := mem.NewDatabase("mydb")
49+
db.AddTable("mytable", table)
50+
51+
catalog := sql.Catalog{db}
52+
53+
a := analyzer.New(catalog)
54+
a.Rules = []analyzer.Rule{f}
55+
a.CurrentDatabase = "mydb"
56+
57+
notAnalyzed := plan.NewProject(
58+
[]sql.Expression{expression.NewGetField(0, sql.Integer, "i")},
59+
plan.NewUnresolvedRelation("mytable"),
60+
)
61+
analyzed := f.Apply(a, notAnalyzed)
62+
expected := plan.NewProject(
63+
[]sql.Expression{expression.NewGetField(0, sql.Integer, "i")},
64+
table,
65+
)
66+
assert.Equal(expected, analyzed)
67+
}
68+
69+
func getRule(name string) analyzer.Rule {
70+
for _, rule := range analyzer.DefaultRules {
71+
if rule.Name == name {
72+
return rule
73+
}
74+
}
75+
panic("missing rule")
76+
}

sql/catalog.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
package sql
2+
3+
import (
4+
"fmt"
5+
)
6+
7+
type Catalog []Database
8+
9+
func (c Catalog) Database(name string) (Database, error) {
10+
for _, db := range []Database(c) {
11+
if db.Name() == name {
12+
return db, nil
13+
}
14+
}
15+
16+
return nil, fmt.Errorf("database not found: %s", name)
17+
}
18+
19+
func (c Catalog) Table(dbName string, tableName string) (PhysicalRelation, error) {
20+
db, err := c.Database(dbName)
21+
if err != nil {
22+
return nil, err
23+
}
24+
25+
tables := db.Relations()
26+
table, found := tables[tableName]
27+
if !found {
28+
return nil, fmt.Errorf("table not found: %s", tableName)
29+
}
30+
31+
return table, nil
32+
}

sql/plan/project.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,8 @@ func (p *Project) TransformExpressionsUp(f func(sql.Expression) sql.Expression)
6464
c := p.UnaryNode.Child.TransformExpressionsUp(f)
6565
es := []sql.Expression{}
6666
for _, e := range p.expressions {
67-
es = append(es, e.TransformUp(f))
67+
te := e.TransformUp(f)
68+
es = append(es, te)
6869
}
6970
n := NewProject(es, c)
7071

sql/plan/unresolved.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ type UnresolvedRelation struct {
1010
Name string
1111
}
1212

13+
func NewUnresolvedRelation(name string) *UnresolvedRelation {
14+
return &UnresolvedRelation{name}
15+
}
16+
1317
func (*UnresolvedRelation) Resolved() bool {
1418
return false
1519
}

sql/plan/unresolved_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,6 @@ import (
99

1010
func TestUnresolvedRelation(t *testing.T) {
1111
assert := assert.New(t)
12-
var r sql.Node = &UnresolvedRelation{"test_table"}
12+
var r sql.Node = NewUnresolvedRelation("test_table")
1313
assert.NotNil(r)
1414
}

0 commit comments

Comments
 (0)