Skip to content

Commit f9d952d

Browse files
cmoogkyleconroy
authored andcommitted
adds support for custom sqlc.arg(MyParam) params
1 parent 014ce72 commit f9d952d

File tree

6 files changed

+128
-17
lines changed

6 files changed

+128
-17
lines changed

examples/booktest/mysql/query.sql

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ WHERE book_id = ?;
4343

4444
/* name: UpdateBookISBN :exec */
4545
UPDATE books
46-
SET title = ?, tags = ?, isbn = ?
46+
SET title = ?, tags = :book_tags, isbn = ?
4747
WHERE book_id = ?;
4848

49+
/* name: DeleteAuthorBeforeYear :exec */
50+
DELETE FROM books
51+
WHERE yr < sqlc.arg(min_publish_year) AND author_id = ?;

examples/booktest/mysql/query.sql.go

Lines changed: 19 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/mysql/param.go

Lines changed: 64 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ func paramsInLimitExpr(limit *sqlparser.Limit, s *Schema, tableAliasMap FromTabl
2323
return params, nil
2424
}
2525

26-
parseLimitSubExp := func(node sqlparser.Expr) {
26+
parseLimitSubExp := func(node sqlparser.Expr) error {
2727
switch v := node.(type) {
2828
case *sqlparser.SQLVal:
2929
if v.Type == sqlparser.ValArg {
@@ -33,11 +33,30 @@ func paramsInLimitExpr(limit *sqlparser.Limit, s *Schema, tableAliasMap FromTabl
3333
Typ: "uint32",
3434
})
3535
}
36+
case *sqlparser.FuncExpr:
37+
name, raw, err := matchFuncExpr(v)
38+
if err != nil {
39+
return err
40+
}
41+
if name != "" && raw != "" {
42+
params = append(params, &Param{
43+
OriginalName: raw,
44+
Name: name,
45+
Typ: "uint32",
46+
})
47+
}
3648
}
49+
return nil
3750
}
3851

39-
parseLimitSubExp(limit.Offset)
40-
parseLimitSubExp(limit.Rowcount)
52+
err := parseLimitSubExp(limit.Offset)
53+
if err != nil {
54+
return nil, err
55+
}
56+
err = parseLimitSubExp(limit.Rowcount)
57+
if err != nil {
58+
return nil, err
59+
}
4160

4261
return params, nil
4362
}
@@ -115,13 +134,26 @@ func paramInComparison(cond *sqlparser.ComparisonExpr, s *Schema, tableAliasMap
115134
if v.Type == sqlparser.ValArg {
116135
p.OriginalName = string(v.Val)
117136
}
137+
case *sqlparser.FuncExpr:
138+
name, raw, err := matchFuncExpr(v)
139+
if err != nil {
140+
return false, err
141+
}
142+
if name != "" && raw != "" {
143+
p.OriginalName = raw
144+
p.Name = name
145+
}
146+
return false, nil
118147
}
119148
return true, nil
120149
}
121150
err := sqlparser.Walk(walker, cond)
122151
if err != nil {
123152
return nil, false, err
124153
}
154+
if p.Name != "" {
155+
return p, true, nil
156+
}
125157
if p.OriginalName != "" && p.Typ != "" {
126158
p.Name = paramName(colIdent, p.OriginalName)
127159
return p, true, nil
@@ -143,11 +175,39 @@ func paramName(col sqlparser.ColIdent, originalName string) string {
143175

144176
func replaceParamStrs(query string, params []*Param) (string, error) {
145177
for _, p := range params {
146-
re, err := regexp.Compile(fmt.Sprintf("(%v)", p.OriginalName))
178+
re, err := regexp.Compile(fmt.Sprintf("(%v)", regexp.QuoteMeta(p.OriginalName)))
147179
if err != nil {
148180
return "", err
149181
}
150182
query = re.ReplaceAllString(query, "?")
151183
}
152184
return query, nil
153185
}
186+
187+
func matchFuncExpr(v *sqlparser.FuncExpr) (name string, raw string, err error) {
188+
namespace := "sqlc"
189+
fakeFunc := "arg"
190+
if v.Qualifier.String() == namespace {
191+
if v.Name.String() == fakeFunc {
192+
if expr, ok := v.Exprs[0].(*sqlparser.AliasedExpr); ok {
193+
if colName, ok := expr.Expr.(*sqlparser.ColName); ok {
194+
customName := colName.Name.String()
195+
return customName, fmt.Sprintf("%s.%s(%s)", namespace, fakeFunc, customName), nil
196+
}
197+
return "", "", fmt.Errorf("invalid custom argument value \"%s.%s(%s)\"", namespace, fakeFunc, replaceVParamExprs(sqlparser.String(v.Exprs[0])))
198+
}
199+
return "", "", fmt.Errorf("invalid custom argument value \"%s.%s(%s)\"", namespace, fakeFunc, replaceVParamExprs(sqlparser.String(v.Exprs[0])))
200+
}
201+
return "", "", fmt.Errorf("invalid function call \"%s.%s\", did you mean \"%s.%s\"?", namespace, v.Name.String(), namespace, fakeFunc)
202+
}
203+
return "", "", nil
204+
}
205+
206+
func replaceVParamExprs(sql string) string {
207+
/*
208+
the sqlparser replaces "?" with ":v1"
209+
to display a helpful error message, these should be replaced back to "?"
210+
*/
211+
matcher := regexp.MustCompile(":v[0-9]*")
212+
return matcher.ReplaceAllString(sql, "?")
213+
}

internal/mysql/param_test.go

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,16 @@ func TestSelectParamSearcher(t *testing.T) {
7575
},
7676
},
7777
},
78+
{
79+
input: "select first_name, id FROM users LIMIT sqlc.arg(UsersLimit)",
80+
output: []*Param{
81+
&Param{
82+
OriginalName: "sqlc.arg(UsersLimit)",
83+
Name: "UsersLimit",
84+
Typ: "uint32",
85+
},
86+
},
87+
},
7888
}
7989
for _, tCase := range tests {
8090
tree, err := sqlparser.Parse(tCase.input)
@@ -118,20 +128,20 @@ func TestInsertParamSearcher(t *testing.T) {
118128

119129
tests := []testCase{
120130
testCase{
121-
input: "/* name: InsertNewUser :exec */\nINSERT INTO users (first_name, last_name) VALUES (?, ?)",
131+
input: "/* name: InsertNewUser :exec */\nINSERT INTO users (first_name, last_name) VALUES (?, sqlc.arg(user_last_name))",
122132
output: []*Param{
123133
&Param{
124134
OriginalName: ":v1",
125135
Name: "first_name",
126136
Typ: "string",
127137
},
128138
&Param{
129-
OriginalName: ":v2",
130-
Name: "last_name",
139+
OriginalName: "sqlc.arg(user_last_name)",
140+
Name: "user_last_name",
131141
Typ: "sql.NullString",
132142
},
133143
},
134-
expectedNames: []string{"first_name", "last_name"},
144+
expectedNames: []string{"first_name", "user_last_name"},
135145
},
136146
}
137147
for _, tCase := range tests {

internal/mysql/parse.go

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,9 +345,30 @@ func parseInsert(node *sqlparser.Insert, query string, s *Schema, settings dinos
345345
}
346346
params = append(params, p)
347347
}
348+
case *sqlparser.FuncExpr:
349+
name, raw, err := matchFuncExpr(v)
348350

351+
if err != nil {
352+
return nil, err
353+
}
354+
if name == "" || raw == "" {
355+
continue
356+
}
357+
colName := cols[colIx].String()
358+
colDfn, err := s.schemaLookup(tableName, colName)
359+
p := &Param{
360+
OriginalName: raw,
361+
}
362+
if err == nil {
363+
p.Name = name
364+
p.Typ = goTypeCol(colDfn, settings)
365+
} else {
366+
p.Name = "Unknown"
367+
p.Typ = "interface{}"
368+
}
369+
params = append(params, p)
349370
default:
350-
panic("Error occurred in parsing INSERT statement")
371+
return nil, fmt.Errorf("failed to parse insert query value")
351372
}
352373
}
353374
}

internal/mysql/schema.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ type Schema struct {
2222
// returns a deep copy of the column definition for using as a query return type or param type
2323
func (s *Schema) getColType(col *sqlparser.ColName, tableAliasMap FromTables, defaultTableName string) (*sqlparser.ColumnDefinition, error) {
2424
realTable, err := tableColReferences(col, defaultTableName, tableAliasMap)
25+
if err != nil {
26+
return nil, err
27+
}
2528

2629
colDfn, err := s.schemaLookup(realTable.TrueName, col.Name.String())
2730
if err != nil {
@@ -38,13 +41,13 @@ func tableColReferences(col *sqlparser.ColName, defaultTable string, tableAliasM
3841
var table FromTable
3942
if col.Qualifier.IsEmpty() {
4043
if defaultTable == "" {
41-
return FromTable{}, fmt.Errorf("Column reference [%v] is ambiguous -- Add a qualifier", col.Name.String())
44+
return FromTable{}, fmt.Errorf("column reference \"%s\" is ambiguous, add a qualifier", col.Name.String())
4245
}
4346
table = FromTable{defaultTable, false}
4447
} else {
4548
fromTable, ok := tableAliasMap[col.Qualifier.Name.String()]
4649
if !ok {
47-
return FromTable{}, fmt.Errorf("Column qualifier [%v] not found in table alias map", col.Qualifier.Name.String())
50+
return FromTable{}, fmt.Errorf("column qualifier \"%s\" is not in schema or is an invalid alias", col.Qualifier.Name.String())
4851
}
4952
return fromTable, nil
5053
}

0 commit comments

Comments
 (0)