Skip to content
This repository was archived by the owner on Jan 28, 2021. It is now read-only.

Commit 35beff6

Browse files
committed
Add Reverse, Repeat, Replace
Signed-off-by: Theo Despoudis <[email protected]>
1 parent 124343f commit 35beff6

File tree

3 files changed

+365
-0
lines changed

3 files changed

+365
-0
lines changed

sql/expression/function/registry.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,4 +61,7 @@ var Defaults = sql.Functions{
6161
"ltrim": sql.Function1(NewTrimFunc(lTrimType)),
6262
"rtrim": sql.Function1(NewTrimFunc(rTrimType)),
6363
"trim": sql.Function1(NewTrimFunc(bTrimType)),
64+
"reverse": sql.Function1(NewReverse),
65+
"repeat": sql.Function2(NewRepeat),
66+
"replace": sql.Function3(NewReplace),
6467
}
Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
1+
package function
2+
3+
import (
4+
"fmt"
5+
"strings"
6+
7+
"gopkg.in/src-d/go-mysql-server.v0/sql/expression"
8+
"gopkg.in/src-d/go-mysql-server.v0/sql"
9+
10+
"gopkg.in/src-d/go-errors.v1"
11+
)
12+
13+
// Reverse is a function that returns the reverse of the text provided.
14+
type Reverse struct {
15+
expression.UnaryExpression
16+
}
17+
18+
// NewLower creates a new Lower expression.
19+
func NewReverse(e sql.Expression) sql.Expression {
20+
return &Reverse{expression.UnaryExpression{Child: e}}
21+
}
22+
23+
// Eval implements the Expression interface.
24+
func (r *Reverse) Eval(
25+
ctx *sql.Context,
26+
row sql.Row,
27+
) (interface{}, error) {
28+
v, err := r.Child.Eval(ctx, row)
29+
if err != nil {
30+
return nil, err
31+
}
32+
33+
if v == nil {
34+
return nil, nil
35+
}
36+
37+
v, err = sql.Text.Convert(v)
38+
if err != nil {
39+
return nil, err
40+
}
41+
42+
return reverseString(v.(string)), nil
43+
}
44+
45+
func reverseString(s string) string {
46+
r := []rune(s)
47+
for i, j := 0, len(r)-1; i < len(r)/2; i, j = i+1, j-1 {
48+
r[i], r[j] = r[j], r[i]
49+
}
50+
return string(r)
51+
}
52+
53+
func (r *Reverse) String() string {
54+
return fmt.Sprintf("reverse(%s)", r.Child)
55+
}
56+
57+
// TransformUp implements the Expression interface.
58+
func (r *Reverse) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) {
59+
child, err := r.Child.TransformUp(f)
60+
if err != nil {
61+
return nil, err
62+
}
63+
return f(NewReverse(child))
64+
}
65+
66+
// Type implements the Expression interface.
67+
func (r *Reverse) Type() sql.Type {
68+
return r.Child.Type()
69+
}
70+
71+
var ErrNegativeRepeatCount = errors.NewKind("negative Repeat count: %v")
72+
73+
// Repeat is a function that returns the string repeated n times.
74+
type Repeat struct {
75+
expression.BinaryExpression
76+
}
77+
78+
// NewRepeat creates a new Repeat expression.
79+
func NewRepeat(str sql.Expression, count sql.Expression) sql.Expression {
80+
return &Repeat{expression.BinaryExpression{Left: str, Right: count}}
81+
}
82+
83+
func (r *Repeat) String() string {
84+
return fmt.Sprintf("repeat(%s, %s)", r.Left, r.Right)
85+
}
86+
87+
// Type implements the Expression interface.
88+
func (r *Repeat) Type() sql.Type {
89+
return sql.Text
90+
}
91+
92+
// TransformUp implements the Expression interface.
93+
func (r *Repeat) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) {
94+
left, err := r.Left.TransformUp(f)
95+
if err != nil {
96+
return nil, err
97+
}
98+
99+
right, err := r.Right.TransformUp(f)
100+
if err != nil {
101+
return nil, err
102+
}
103+
return f(NewRepeat(left, right))
104+
}
105+
106+
// Eval implements the Expression interface.
107+
func (r *Repeat) Eval(
108+
ctx *sql.Context,
109+
row sql.Row,
110+
) (interface{}, error) {
111+
str, err := r.Left.Eval(ctx, row)
112+
if err != nil {
113+
return nil, err
114+
}
115+
116+
if str == nil {
117+
return nil, nil
118+
}
119+
120+
str, err = sql.Text.Convert(str)
121+
if err != nil {
122+
return nil, err
123+
}
124+
125+
count, err := r.Right.Eval(ctx, row)
126+
if err != nil {
127+
return nil, err
128+
}
129+
130+
if count == nil {
131+
return nil, nil
132+
}
133+
134+
count, err = sql.Int32.Convert(count)
135+
if err != nil {
136+
return nil, err
137+
}
138+
if count.(int32) < 0 {
139+
return nil, ErrNegativeRepeatCount.New(count)
140+
}
141+
return strings.Repeat(str.(string), int(count.(int32))), nil
142+
}
143+
144+
// Replace is a function that returns a string with all occurrences of fromStr replaced by the
145+
// string toStr
146+
type Replace struct {
147+
str sql.Expression
148+
fromStr sql.Expression
149+
toStr sql.Expression
150+
}
151+
152+
// NewReplace creates a new Replace expression.
153+
func NewReplace(str sql.Expression, fromStr sql.Expression, toStr sql.Expression) sql.Expression {
154+
return &Replace{str, fromStr, toStr}
155+
}
156+
157+
// Children implements the Expression interface.
158+
func (r *Replace) Children() []sql.Expression {
159+
return []sql.Expression{r.str, r.fromStr, r.toStr}
160+
}
161+
162+
// Resolved implements the Expression interface.
163+
func (r *Replace) Resolved() bool {
164+
return r.str.Resolved() && r.fromStr.Resolved() && r.toStr.Resolved()
165+
}
166+
167+
// IsNullable implements the Expression interface.
168+
func (r *Replace) IsNullable() bool {
169+
return r.str.IsNullable() || r.fromStr.IsNullable() || r.toStr.IsNullable()
170+
}
171+
172+
func (r *Replace) String() string {
173+
return fmt.Sprintf("replace(%s, %s, %s)", r.str, r.fromStr, r.toStr)
174+
}
175+
176+
// Type implements the Expression interface.
177+
func (r *Replace) Type() sql.Type {
178+
return sql.Text
179+
}
180+
181+
// TransformUp implements the Expression interface.
182+
func (r *Replace) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) {
183+
str, err := r.str.TransformUp(f)
184+
if err != nil {
185+
return nil, err
186+
}
187+
188+
fromStr, err := r.fromStr.TransformUp(f)
189+
if err != nil {
190+
return nil, err
191+
}
192+
193+
toStr, err := r.toStr.TransformUp(f)
194+
if err != nil {
195+
return nil, err
196+
}
197+
return f(NewReplace(str, fromStr, toStr))
198+
}
199+
200+
// Eval implements the Expression interface.
201+
func (r *Replace) Eval(
202+
ctx *sql.Context,
203+
row sql.Row,
204+
) (interface{}, error) {
205+
str, err := r.str.Eval(ctx, row)
206+
if err != nil {
207+
return nil, err
208+
}
209+
210+
if str == nil {
211+
return nil, nil
212+
}
213+
214+
str, err = sql.Text.Convert(str)
215+
if err != nil {
216+
return nil, err
217+
}
218+
219+
fromStr, err := r.fromStr.Eval(ctx, row)
220+
if err != nil {
221+
return nil, err
222+
}
223+
224+
if fromStr == nil {
225+
return nil, nil
226+
}
227+
228+
fromStr, err = sql.Text.Convert(fromStr)
229+
if err != nil {
230+
return nil, err
231+
}
232+
233+
toStr, err := r.toStr.Eval(ctx, row)
234+
if err != nil {
235+
return nil, err
236+
}
237+
238+
if toStr == nil {
239+
return nil, nil
240+
}
241+
242+
toStr, err = sql.Text.Convert(toStr)
243+
if err != nil {
244+
return nil, err
245+
}
246+
247+
if fromStr.(string) == "" {
248+
return str, nil
249+
}
250+
251+
return strings.Replace(str.(string), fromStr.(string), toStr.(string), -1), nil
252+
}
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
package function
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/require"
7+
"gopkg.in/src-d/go-mysql-server.v0/sql"
8+
"gopkg.in/src-d/go-mysql-server.v0/sql/expression"
9+
)
10+
11+
func TestReverse(t *testing.T) {
12+
f := NewReverse(expression.NewGetField(0, sql.Text, "", false))
13+
testCases := []struct {
14+
name string
15+
row sql.Row
16+
expected interface{}
17+
err bool
18+
}{
19+
{"null input", sql.NewRow(nil), nil, false},
20+
{"empty string", sql.NewRow(""), "", false},
21+
{"handles numbers as strings", sql.NewRow(123), "321", false},
22+
{"valid string", sql.NewRow("foobar"), "raboof", false},
23+
}
24+
for _, tt := range testCases {
25+
t.Run(tt.name, func(t *testing.T) {
26+
t.Helper()
27+
require := require.New(t)
28+
ctx := sql.NewEmptyContext()
29+
30+
v, err := f.Eval(ctx, tt.row)
31+
if tt.err {
32+
require.Error(err)
33+
} else {
34+
require.NoError(err)
35+
require.Equal(tt.expected, v)
36+
}
37+
})
38+
}
39+
}
40+
41+
func TestRepeat(t *testing.T) {
42+
f := NewRepeat(
43+
expression.NewGetField(0, sql.Text, "", false),
44+
expression.NewGetField(1, sql.Int32, "", false),
45+
)
46+
47+
testCases := []struct {
48+
name string
49+
row sql.Row
50+
expected interface{}
51+
err bool
52+
}{
53+
{"null input", sql.NewRow(nil), nil, false},
54+
{"empty string", sql.NewRow("", 2), "", false},
55+
{"count is zero", sql.NewRow("foo", 0), "", false},
56+
{"count is negative", sql.NewRow("foo", -2), "foo", true},
57+
{"valid string", sql.NewRow("foobar", 2), "foobarfoobar", false},
58+
}
59+
for _, tt := range testCases {
60+
t.Run(tt.name, func(t *testing.T) {
61+
t.Helper()
62+
require := require.New(t)
63+
ctx := sql.NewEmptyContext()
64+
65+
v, err := f.Eval(ctx, tt.row)
66+
if tt.err {
67+
require.Error(err)
68+
} else {
69+
require.NoError(err)
70+
require.Equal(tt.expected, v)
71+
}
72+
})
73+
}
74+
}
75+
76+
func TestReplace(t *testing.T) {
77+
f := NewReplace(
78+
expression.NewGetField(0, sql.Text, "", false),
79+
expression.NewGetField(1, sql.Text, "", false),
80+
expression.NewGetField(2, sql.Text, "", false),
81+
)
82+
83+
testCases := []struct {
84+
name string
85+
row sql.Row
86+
expected interface{}
87+
err bool
88+
}{
89+
{"null inputs", sql.NewRow(nil), nil, false},
90+
{"empty str", sql.NewRow("", "foo", "bar"), "", false},
91+
{"empty fromStr", sql.NewRow("foobarfoobar", "", "car"), "foobarfoobar", false},
92+
{"empty toStr", sql.NewRow("foobarfoobar", "bar", ""), "foofoo", false},
93+
{"valid strings", sql.NewRow("foobarfoobar", "bar", "car"), "foocarfoocar", false},
94+
}
95+
for _, tt := range testCases {
96+
t.Run(tt.name, func(t *testing.T) {
97+
t.Helper()
98+
require := require.New(t)
99+
ctx := sql.NewEmptyContext()
100+
101+
v, err := f.Eval(ctx, tt.row)
102+
if tt.err {
103+
require.Error(err)
104+
} else {
105+
require.NoError(err)
106+
require.Equal(tt.expected, v)
107+
}
108+
})
109+
}
110+
}

0 commit comments

Comments
 (0)