Skip to content

Commit 42bf507

Browse files
authored
add TransformExpressionsUp. (#29)
* Fix Sort to use Expressions.
1 parent cd8af01 commit 42bf507

File tree

12 files changed

+81
-48
lines changed

12 files changed

+81
-48
lines changed

git/commits.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ func (r *commitsRelation) TransformUp(f func(sql.Node) sql.Node) sql.Node {
4141
return f(newCommitsRelation(r.r))
4242
}
4343

44+
func (r *commitsRelation) TransformExpressionsUp(f func(sql.Expression) sql.Expression) sql.Node {
45+
return r
46+
}
47+
4448
func (r commitsRelation) RowIter() (sql.RowIter, error) {
4549
cIter, err := r.r.Commits()
4650
if err != nil {

mem/table.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,10 @@ func (t *Table) TransformUp(f func(sql.Node) sql.Node) sql.Node {
4545
return f(NewTable(t.name, t.schema))
4646
}
4747

48+
func (t *Table) TransformExpressionsUp(f func(sql.Expression) sql.Expression) sql.Node {
49+
return t
50+
}
51+
4852
func (t *Table) Insert(values ...interface{}) error {
4953
if len(values) != len(t.schema) {
5054
return fmt.Errorf("insert expected %d values, got %d", len(t.schema), len(values))

sql/core.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ type Resolvable interface {
1212

1313
type Transformable interface {
1414
TransformUp(func(Node) Node) Node
15+
TransformExpressionsUp(func(Expression) Expression) Node
1516
}
1617

1718
type Node interface {

sql/expression/star.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,7 @@ func (Star) Name() string {
2424
func (Star) Eval(r sql.Row) interface{} {
2525
return "FAIL" //FIXME
2626
}
27+
28+
func (s *Star) TransformUp(f func(sql.Expression) sql.Expression) sql.Expression {
29+
return f(s)
30+
}

sql/parse/parse.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"strings"
88

99
"github.com/mvader/gitql/sql"
10+
"github.com/mvader/gitql/sql/expression"
1011
"github.com/mvader/gitql/sql/plan"
1112
)
1213

@@ -296,7 +297,7 @@ func parseOrderClause(q tokenQueue) ([]plan.SortField, error) {
296297
return nil, fmt.Errorf(`expecting "DESC", "ASC" or ",", received %q`, tk.Value)
297298
}
298299

299-
field = &plan.SortField{Column: tk.Value}
300+
field = &plan.SortField{Column: expression.NewUnresolvedColumn(tk.Value)}
300301
case KeywordToken:
301302
if field == nil {
302303
return nil, fmt.Errorf(`unexpected keyword %q, expecting identifier`, tk.Value)

sql/plan/filter.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,14 @@ func (p *Filter) TransformUp(f func(sql.Node) sql.Node) sql.Node {
3737
return f(n)
3838
}
3939

40+
func (p *Filter) TransformExpressionsUp(f func(sql.Expression) sql.Expression) sql.Node {
41+
c := p.UnaryNode.Child.TransformExpressionsUp(f)
42+
e := p.expression.TransformUp(f)
43+
n := NewFilter(e, c)
44+
45+
return n
46+
}
47+
4048
type filterIter struct {
4149
f *Filter
4250
childIter sql.RowIter

sql/plan/limit.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,13 @@ func (l *Limit) TransformUp(f func(sql.Node) sql.Node) sql.Node {
4141
return f(n)
4242
}
4343

44+
func (l *Limit) TransformExpressionsUp(f func(sql.Expression) sql.Expression) sql.Node {
45+
c := l.UnaryNode.Child.TransformExpressionsUp(f)
46+
n := NewLimit(l.size, c)
47+
48+
return n
49+
}
50+
4451
type limitIter struct {
4552
l *Limit
4653
currentPos int64

sql/plan/project.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,17 @@ func (p *Project) TransformUp(f func(sql.Node) sql.Node) sql.Node {
5151
return f(n)
5252
}
5353

54+
func (p *Project) TransformExpressionsUp(f func(sql.Expression) sql.Expression) sql.Node {
55+
c := p.UnaryNode.Child.TransformExpressionsUp(f)
56+
es := []sql.Expression{}
57+
for _, e := range p.expressions {
58+
es = append(es, e.TransformUp(f))
59+
}
60+
n := NewProject(es, c)
61+
62+
return n
63+
}
64+
5465
type iter struct {
5566
p *Project
5667
childIter sql.RowIter

sql/plan/sort.go

Lines changed: 26 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package plan
22

33
import (
4-
"fmt"
54
"io"
65
"sort"
76

@@ -10,9 +9,7 @@ import (
109

1110
type Sort struct {
1211
UnaryNode
13-
fieldIndexes []int
14-
fieldTypes []sql.Type
15-
sortFields []SortField
12+
sortFields []SortField
1613
}
1714

1815
type SortOrder byte
@@ -23,33 +20,14 @@ const (
2320
)
2421

2522
type SortField struct {
26-
Column string
23+
Column sql.Expression
2724
Order SortOrder
2825
}
2926

3027
func NewSort(sortFields []SortField, child sql.Node) *Sort {
31-
indexes := []int{}
32-
types := []sql.Type{}
33-
childSchema := child.Schema()
34-
for _, sortField := range sortFields {
35-
found := false
36-
for idx, field := range childSchema {
37-
if field.Name == sortField.Column {
38-
indexes = append(indexes, idx)
39-
types = append(types, field.Type)
40-
found = true
41-
break
42-
}
43-
}
44-
if found == false {
45-
panic(fmt.Errorf("Field %s not found in child", sortField.Column))
46-
}
47-
}
4828
return &Sort{
49-
fieldIndexes: indexes,
50-
fieldTypes: types,
51-
UnaryNode: UnaryNode{child},
52-
sortFields: sortFields,
29+
UnaryNode: UnaryNode{child},
30+
sortFields: sortFields,
5331
}
5432
}
5533

@@ -62,6 +40,7 @@ func (s *Sort) Schema() sql.Schema {
6240
}
6341

6442
func (s *Sort) RowIter() (sql.RowIter, error) {
43+
6544
i, err := s.UnaryNode.Child.RowIter()
6645
if err != nil {
6746
return nil, err
@@ -76,6 +55,17 @@ func (s *Sort) TransformUp(f func(sql.Node) sql.Node) sql.Node {
7655
return f(n)
7756
}
7857

58+
func (s *Sort) TransformExpressionsUp(f func(sql.Expression) sql.Expression) sql.Node {
59+
c := s.UnaryNode.Child.TransformExpressionsUp(f)
60+
sfs := []SortField{}
61+
for _, sf := range s.sortFields {
62+
sfs = append(sfs, SortField{sf.Column.TransformUp(f), sf.Order})
63+
}
64+
n := NewSort(sfs, c)
65+
66+
return n
67+
}
68+
7969
type sortIter struct {
8070
s *Sort
8171
childIter sql.RowIter
@@ -123,18 +113,16 @@ func (i *sortIter) computeSortedRows() error {
123113
rows = append(rows, childRow)
124114
}
125115
sort.Sort(&sorter{
126-
indexes: i.s.fieldIndexes,
127-
types: i.s.fieldTypes,
128-
rows: rows,
116+
sortFields: i.s.sortFields,
117+
rows: rows,
129118
})
130119
i.sortedRows = rows
131120
return nil
132121
}
133122

134123
type sorter struct {
135-
indexes []int
136-
types []sql.Type
137-
rows []sql.Row
124+
sortFields []SortField
125+
rows []sql.Row
138126
}
139127

140128
func (s *sorter) Len() int {
@@ -146,12 +134,12 @@ func (s *sorter) Swap(i, j int) {
146134
}
147135

148136
func (s *sorter) Less(i, j int) bool {
149-
a := s.rows[i].Fields()
150-
b := s.rows[j].Fields()
151-
for i, idx := range s.indexes {
152-
typ := s.types[i]
153-
av := a[idx]
154-
bv := b[idx]
137+
a := s.rows[i]
138+
b := s.rows[j]
139+
for _, sf := range s.sortFields {
140+
typ := sf.Column.Type()
141+
av := sf.Column.Eval(a)
142+
bv := sf.Column.Eval(b)
155143
if typ.Compare(av, bv) == -1 {
156144
return true
157145
}

sql/plan/sort_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66

77
"github.com/mvader/gitql/mem"
88
"github.com/mvader/gitql/sql"
9+
"github.com/mvader/gitql/sql/expression"
910
"github.com/stretchr/testify/assert"
1011
)
1112

@@ -20,8 +21,8 @@ func TestSort(t *testing.T) {
2021
child.Insert("b", int32(3))
2122
child.Insert("c", int32(1))
2223
sf := []SortField{
23-
{Column: "col2", Order: Ascending},
24-
{Column: "col1", Order: Descending},
24+
{Column: expression.NewGetField(1, sql.Integer, "col2"), Order: Ascending},
25+
{Column: expression.NewGetField(0, sql.String, "col1"), Order: Descending},
2526
}
2627
s := NewSort(sf, child)
2728
assert.Equal(childSchema, s.Schema())

0 commit comments

Comments
 (0)