Skip to content
This repository was archived by the owner on Jan 28, 2021. It is now read-only.

Commit 9ee3f1d

Browse files
authored
Merge pull request #555 from kuba--/fix-546/null
Implement ifnull and nullif functions.
2 parents 5a108b2 + 88d2d71 commit 9ee3f1d

File tree

9 files changed

+366
-4
lines changed

9 files changed

+366
-4
lines changed

README.md

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,21 @@ go get gopkg.in/src-d/go-mysql-server.v0
5050

5151
We are continuously adding more functionality to go-mysql-server. We support a subset of what is supported in MySQL, to see what is currently included check the [SUPPORTED](./SUPPORTED.md) file.
5252

53-
# Third-party clients
53+
## Third-party clients
5454

5555
We support and actively test against certain third-party clients to ensure compatibility between them and go-mysql-server. You can check out the list of supported third party clients in the [SUPPORTED_CLIENTS](./SUPPORTED_CLIENTS.md) file along with some examples on how to connect to go-mysql-server using them.
5656

5757
## Custom functions
5858

59+
- `COUNT(expr)`: Returns a count of the number of non-NULL values of expr in the rows retrieved by a SELECT statement.
60+
- `MIN(expr)`: Returns the minimum value of expr.
61+
- `MAX(expr)`: Returns the maximum value of expr.
62+
- `AVG(expr)`: Returns the average value of expr.
63+
- `SUM(expr)`: Returns the sum of expr.
5964
- `IS_BINARY(blob)`: Returns whether a BLOB is a binary file or not.
60-
- `SUBSTRING(str, pos)`, `SUBSTRING(str, pos, len)`: Return a substring from the provided string.
65+
- `SUBSTRING(str, pos)`, `SUBSTRING(str, pos, len)` : Return a substring from the provided string.
66+
- `SUBSTR(str, pos)`, `SUBSTR(str, pos, len)` : Return a substring from the provided string.
67+
- `MID(str, pos)`, `MID(str, pos, len)` : Return a substring from the provided string.
6168
- Date and Timestamp functions: `YEAR(date)`, `MONTH(date)`, `DAY(date)`, `WEEKDAY(date)`, `HOUR(date)`, `MINUTE(date)`, `SECOND(date)`, `DAYOFWEEK(date)`, `DAYOFYEAR(date)`.
6269
- `ARRAY_LENGTH(json)`: If the json representation is an array, this function returns its size.
6370
- `SPLIT(str,sep)`: Receives a string and a separator and returns the parts of the string split by the separator as a JSON array of strings.
@@ -70,6 +77,23 @@ We support and actively test against certain third-party clients to ensure compa
7077
- `ROUND(number, decimals)`: Round the `number` to `decimals` decimal places.
7178
- `CONNECTION_ID()`: Return the current connection ID.
7279
- `SOUNDEX(str)`: Returns the soundex of a string.
80+
- `JSON_EXTRACT(json_doc, path, ...)`: Extracts data from a json document using json paths.
81+
- `LN(X)`: Return the natural logarithm of X.
82+
- `LOG2(X)`: Returns the base-2 logarithm of X.
83+
- `LOG10(X)`: Returns the base-10 logarithm of X.
84+
- `LOG(X), LOG(B, X)`: If called with one parameter, this function returns the natural logarithm of X. If called with two parameters, this function returns the logarithm of X to the base B. If X is less than or equal to 0, or if B is less than or equal to 1, then NULL is returned.
85+
- `RPAD(str, len, padstr)`: Returns the string str, right-padded with the string padstr to a length of len characters.
86+
- `LPAD(str, len, padstr)`: Return the string argument, left-padded with the specified string.
87+
- `SQRT(X)`: Returns the square root of a nonnegative number X.
88+
- `POW(X, Y)`, `POWER(X, Y)`: Returns the value of X raised to the power of Y.
89+
- `TRIM(str)`: Returns the string str with all spaces removed.
90+
- `LTRIM(str)`: Returns the string str with leading space characters removed.
91+
- `RTRIM(str)`: Returns the string str with trailing space characters removed.
92+
- `REVERSE(str)`: Returns the string str with the order of the characters reversed.
93+
- `REPEAT(str, count)`: Returns a string consisting of the string str repeated count times.
94+
- `REPLACE(str,from_str,to_str)`: Returns the string str with all occurrences of the string from_str replaced by the string to_str.
95+
- `IFNULL(expr1, expr2)`: If expr1 is not NULL, IFNULL() returns expr1; otherwise it returns expr2.
96+
- `NULLIF(expr1, expr2)`: Returns NULL if expr1 = expr2 is true, otherwise returns expr1.
7397

7498
## Example
7599

engine_test.go

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -700,6 +700,72 @@ var queries = []struct {
700700
{"tabletest", "InnoDB", "10", "Fixed", int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), nil, nil, nil, "utf8_bin", nil, nil},
701701
},
702702
},
703+
{
704+
`SELECT NULL`,
705+
[]sql.Row{
706+
{nil},
707+
},
708+
},
709+
{
710+
`SELECT nullif('abc', NULL)`,
711+
[]sql.Row{
712+
{"abc"},
713+
},
714+
},
715+
{
716+
`SELECT nullif(NULL, NULL)`,
717+
[]sql.Row{
718+
{sql.Null},
719+
},
720+
},
721+
{
722+
`SELECT nullif(NULL, 123)`,
723+
[]sql.Row{
724+
{nil},
725+
},
726+
},
727+
{
728+
`SELECT nullif(123, 123)`,
729+
[]sql.Row{
730+
{sql.Null},
731+
},
732+
},
733+
{
734+
`SELECT nullif(123, 321)`,
735+
[]sql.Row{
736+
{int64(123)},
737+
},
738+
},
739+
{
740+
`SELECT ifnull(123, NULL)`,
741+
[]sql.Row{
742+
{int64(123)},
743+
},
744+
},
745+
{
746+
`SELECT ifnull(NULL, NULL)`,
747+
[]sql.Row{
748+
{nil},
749+
},
750+
},
751+
{
752+
`SELECT ifnull(NULL, 123)`,
753+
[]sql.Row{
754+
{int64(123)},
755+
},
756+
},
757+
{
758+
`SELECT ifnull(123, 123)`,
759+
[]sql.Row{
760+
{int64(123)},
761+
},
762+
},
763+
{
764+
`SELECT ifnull(123, 321)`,
765+
[]sql.Row{
766+
{int64(123)},
767+
},
768+
},
703769
}
704770

705771
func TestQueries(t *testing.T) {

sql/expression/function/ifnull.go

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
package function
2+
3+
import (
4+
"fmt"
5+
6+
"gopkg.in/src-d/go-mysql-server.v0/sql"
7+
"gopkg.in/src-d/go-mysql-server.v0/sql/expression"
8+
)
9+
10+
// IfNull function returns the specified value IF the expression is NULL, otherwise return the expression.
11+
type IfNull struct {
12+
expression.BinaryExpression
13+
}
14+
15+
// NewIfNull returns a new IFNULL UDF
16+
func NewIfNull(ex, value sql.Expression) sql.Expression {
17+
return &IfNull{
18+
expression.BinaryExpression{
19+
Left: ex,
20+
Right: value,
21+
},
22+
}
23+
}
24+
25+
// Eval implements the Expression interface.
26+
func (f *IfNull) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
27+
left, err := f.Left.Eval(ctx, row)
28+
if err != nil {
29+
return nil, err
30+
}
31+
if left != nil {
32+
return left, nil
33+
}
34+
35+
right, err := f.Right.Eval(ctx, row)
36+
if err != nil {
37+
return nil, err
38+
}
39+
return right, nil
40+
}
41+
42+
// Type implements the Expression interface.
43+
func (f *IfNull) Type() sql.Type {
44+
if sql.IsNull(f.Left) {
45+
if sql.IsNull(f.Right) {
46+
return sql.Null
47+
}
48+
return f.Right.Type()
49+
}
50+
return f.Left.Type()
51+
}
52+
53+
// IsNullable implements the Expression interface.
54+
func (f *IfNull) IsNullable() bool {
55+
if sql.IsNull(f.Left) {
56+
if sql.IsNull(f.Right) {
57+
return true
58+
}
59+
return f.Right.IsNullable()
60+
}
61+
return f.Left.IsNullable()
62+
}
63+
64+
func (f *IfNull) String() string {
65+
return fmt.Sprintf("ifnull(%s, %s)", f.Left, f.Right)
66+
}
67+
68+
// TransformUp implements the Expression interface.
69+
func (f *IfNull) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) {
70+
left, err := f.Left.TransformUp(fn)
71+
if err != nil {
72+
return nil, err
73+
}
74+
75+
right, err := f.Right.TransformUp(fn)
76+
if err != nil {
77+
return nil, err
78+
}
79+
80+
return fn(NewIfNull(left, right))
81+
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package function
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/require"
7+
"gopkg.in/src-d/go-mysql-server.v0/sql"
8+
"gopkg.in/src-d/go-mysql-server.v0/sql/expression"
9+
)
10+
11+
func TestIfNull(t *testing.T) {
12+
testCases := []struct {
13+
expression interface{}
14+
value interface{}
15+
expected interface{}
16+
}{
17+
{"foo", "bar", "foo"},
18+
{"foo", "foo", "foo"},
19+
{nil, "foo", "foo"},
20+
{"foo", nil, "foo"},
21+
{nil, nil, nil},
22+
{"", nil, ""},
23+
}
24+
25+
f := NewIfNull(
26+
expression.NewGetField(0, sql.Text, "expression", true),
27+
expression.NewGetField(1, sql.Text, "value", true),
28+
)
29+
require.Equal(t, sql.Text, f.Type())
30+
31+
for _, tc := range testCases {
32+
v, err := f.Eval(sql.NewEmptyContext(), sql.NewRow(tc.expression, tc.value))
33+
require.NoError(t, err)
34+
require.Equal(t, tc.expected, v)
35+
}
36+
}

sql/expression/function/nullif.go

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
package function
2+
3+
import (
4+
"fmt"
5+
6+
"gopkg.in/src-d/go-mysql-server.v0/sql"
7+
"gopkg.in/src-d/go-mysql-server.v0/sql/expression"
8+
)
9+
10+
// NullIf function compares two expressions and returns NULL if they are equal. Otherwise, the first expression is returned.
11+
type NullIf struct {
12+
expression.BinaryExpression
13+
}
14+
15+
// NewNullIf returns a new NULLIF UDF
16+
func NewNullIf(ex1, ex2 sql.Expression) sql.Expression {
17+
return &NullIf{
18+
expression.BinaryExpression{
19+
Left: ex1,
20+
Right: ex2,
21+
},
22+
}
23+
}
24+
25+
// Eval implements the Expression interface.
26+
func (f *NullIf) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
27+
if sql.IsNull(f.Left) && sql.IsNull(f.Right) {
28+
return sql.Null, nil
29+
}
30+
31+
val, err := expression.NewEquals(f.Left, f.Right).Eval(ctx, row)
32+
if err != nil {
33+
return nil, err
34+
}
35+
if b, ok := val.(bool); ok && b {
36+
return sql.Null, nil
37+
}
38+
39+
return f.Left.Eval(ctx, row)
40+
}
41+
42+
// Type implements the Expression interface.
43+
func (f *NullIf) Type() sql.Type {
44+
if sql.IsNull(f.Left) {
45+
return sql.Null
46+
}
47+
48+
return f.Left.Type()
49+
}
50+
51+
// IsNullable implements the Expression interface.
52+
func (f *NullIf) IsNullable() bool {
53+
return true
54+
}
55+
56+
func (f *NullIf) String() string {
57+
return fmt.Sprintf("nullif(%s, %s)", f.Left, f.Right)
58+
}
59+
60+
// TransformUp implements the Expression interface.
61+
func (f *NullIf) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) {
62+
left, err := f.Left.TransformUp(fn)
63+
if err != nil {
64+
return nil, err
65+
}
66+
67+
right, err := f.Right.TransformUp(fn)
68+
if err != nil {
69+
return nil, err
70+
}
71+
72+
return fn(NewNullIf(left, right))
73+
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package function
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/require"
7+
"gopkg.in/src-d/go-mysql-server.v0/sql"
8+
"gopkg.in/src-d/go-mysql-server.v0/sql/expression"
9+
)
10+
11+
func TestNullIf(t *testing.T) {
12+
testCases := []struct {
13+
ex1 interface{}
14+
ex2 interface{}
15+
expected interface{}
16+
}{
17+
{"foo", "bar", "foo"},
18+
{"foo", "foo", sql.Null},
19+
{nil, "foo", nil},
20+
{"foo", nil, "foo"},
21+
{nil, nil, nil},
22+
{"", nil, ""},
23+
}
24+
25+
f := NewNullIf(
26+
expression.NewGetField(0, sql.Text, "ex1", true),
27+
expression.NewGetField(1, sql.Text, "ex2", true),
28+
)
29+
require.Equal(t, sql.Text, f.Type())
30+
31+
for _, tc := range testCases {
32+
v, err := f.Eval(sql.NewEmptyContext(), sql.NewRow(tc.ex1, tc.ex2))
33+
require.NoError(t, err)
34+
require.Equal(t, tc.expected, v)
35+
}
36+
}

sql/expression/function/registry.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,16 +41,16 @@ var Defaults = sql.Functions{
4141
"split": sql.Function2(NewSplit),
4242
"concat": sql.FunctionN(NewConcat),
4343
"concat_ws": sql.FunctionN(NewConcatWithSeparator),
44+
"coalesce": sql.FunctionN(NewCoalesce),
4445
"lower": sql.Function1(NewLower),
4546
"upper": sql.Function1(NewUpper),
4647
"ceiling": sql.Function1(NewCeil),
4748
"ceil": sql.Function1(NewCeil),
4849
"floor": sql.Function1(NewFloor),
4950
"round": sql.FunctionN(NewRound),
50-
"coalesce": sql.FunctionN(NewCoalesce),
51-
"json_extract": sql.FunctionN(NewJSONExtract),
5251
"connection_id": sql.Function0(NewConnectionID),
5352
"soundex": sql.Function1(NewSoundex),
53+
"json_extract": sql.FunctionN(NewJSONExtract),
5454
"ln": sql.Function1(NewLogBaseFunc(float64(math.E))),
5555
"log2": sql.Function1(NewLogBaseFunc(float64(2))),
5656
"log10": sql.Function1(NewLogBaseFunc(float64(10))),
@@ -66,4 +66,6 @@ var Defaults = sql.Functions{
6666
"reverse": sql.Function1(NewReverse),
6767
"repeat": sql.Function2(NewRepeat),
6868
"replace": sql.Function3(NewReplace),
69+
"ifnull": sql.Function2(NewIfNull),
70+
"nullif": sql.Function2(NewNullIf),
6971
}

0 commit comments

Comments
 (0)