Skip to content

Commit e96d391

Browse files
committed
SQLite: Fix convertion of UPDATE statements table name
Signed-off-by: Nathanael DEMACON <[email protected]>
1 parent b34aa37 commit e96d391

File tree

3 files changed

+74
-39
lines changed

3 files changed

+74
-39
lines changed

internal/endtoend/testdata/update_set/sqlite/go/query.sql.go

Lines changed: 14 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,5 @@
11
/* name: UpdateSet :exec */
22
UPDATE foo SET name = ? WHERE slug = ?;
3+
4+
/* name: UpdateSetQuoted :exec */
5+
UPDATE "foo" SET "name" = ? WHERE "slug" = ?;

internal/engine/sqlite/convert.go

Lines changed: 57 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1021,6 +1021,8 @@ func (c *cc) convertTablesOrSubquery(n []parser.ITable_or_subqueryContext) []ast
10211021
}
10221022

10231023
type Update_stmt interface {
1024+
node
1025+
10241026
Qualified_table_name() parser.IQualified_table_nameContext
10251027
GetStart() antlr.Token
10261028
AllColumn_name() []parser.IColumn_nameContext
@@ -1034,50 +1036,66 @@ func (c *cc) convertUpdate_stmtContext(n Update_stmt) ast.Node {
10341036
return nil
10351037
}
10361038

1037-
relations := &ast.List{}
1038-
tableName := n.Qualified_table_name().GetText()
1039-
rel := ast.RangeVar{
1040-
Relname: &tableName,
1041-
Location: n.GetStart().GetStart(),
1042-
}
1043-
relations.Items = append(relations.Items, &rel)
1039+
if qualifiedName, ok := n.Qualified_table_name().(*parser.Qualified_table_nameContext); ok {
1040+
tableName := identifier(qualifiedName.Table_name().GetText())
1041+
rel := ast.RangeVar{
1042+
Relname: &tableName,
1043+
Location: n.GetStart().GetStart(),
1044+
}
10441045

1045-
list := &ast.List{}
1046-
for i, col := range n.AllColumn_name() {
1047-
colName := identifier(col.GetText())
1048-
target := &ast.ResTarget{
1049-
Name: &colName,
1050-
Val: c.convert(n.Expr(i)),
1046+
if qualifiedName.Schema_name() != nil {
1047+
schemaName := qualifiedName.Schema_name().GetText()
1048+
rel.Schemaname = &schemaName
10511049
}
1052-
list.Items = append(list.Items, target)
1053-
}
10541050

1055-
var where ast.Node = nil
1056-
if n.WHERE_() != nil {
1057-
where = c.convert(n.Expr(len(n.AllExpr()) - 1))
1058-
}
1051+
if qualifiedName.Alias() != nil {
1052+
alias := qualifiedName.Alias().GetText()
1053+
rel.Alias = &ast.Alias{Aliasname: &alias}
1054+
}
10591055

1060-
stmt := &ast.UpdateStmt{
1061-
Relations: relations,
1062-
TargetList: list,
1063-
WhereClause: where,
1064-
FromClause: &ast.List{},
1065-
WithClause: nil, // TODO: support with clause
1066-
}
1067-
if n, ok := n.(interface {
1068-
Returning_clause() parser.IReturning_clauseContext
1069-
}); ok {
1070-
stmt.ReturningList = c.convertReturning_caluseContext(n.Returning_clause())
1071-
} else {
1072-
stmt.ReturningList = c.convertReturning_caluseContext(nil)
1073-
}
1074-
if n, ok := n.(interface {
1075-
Limit_stmt() parser.ILimit_stmtContext
1076-
}); ok {
1077-
limitCount, _ := c.convertLimit_stmtContext(n.Limit_stmt())
1078-
stmt.LimitCount = limitCount
1056+
relations := &ast.List{}
1057+
1058+
relations.Items = append(relations.Items, &rel)
1059+
1060+
list := &ast.List{}
1061+
for i, col := range n.AllColumn_name() {
1062+
colName := identifier(col.GetText())
1063+
target := &ast.ResTarget{
1064+
Name: &colName,
1065+
Val: c.convert(n.Expr(i)),
1066+
}
1067+
list.Items = append(list.Items, target)
1068+
}
1069+
1070+
var where ast.Node = nil
1071+
if n.WHERE_() != nil {
1072+
where = c.convert(n.Expr(len(n.AllExpr()) - 1))
1073+
}
1074+
1075+
stmt := &ast.UpdateStmt{
1076+
Relations: relations,
1077+
TargetList: list,
1078+
WhereClause: where,
1079+
FromClause: &ast.List{},
1080+
WithClause: nil, // TODO: support with clause
1081+
}
1082+
if n, ok := n.(interface {
1083+
Returning_clause() parser.IReturning_clauseContext
1084+
}); ok {
1085+
stmt.ReturningList = c.convertReturning_caluseContext(n.Returning_clause())
1086+
} else {
1087+
stmt.ReturningList = c.convertReturning_caluseContext(nil)
1088+
}
1089+
if n, ok := n.(interface {
1090+
Limit_stmt() parser.ILimit_stmtContext
1091+
}); ok {
1092+
limitCount, _ := c.convertLimit_stmtContext(n.Limit_stmt())
1093+
stmt.LimitCount = limitCount
1094+
}
1095+
return stmt
10791096
}
1080-
return stmt
1097+
1098+
return todo("convertUpdate_stmtContext", n)
10811099
}
10821100

10831101
func (c *cc) convertBetweenExpr(n *parser.Expr_betweenContext) ast.Node {

0 commit comments

Comments
 (0)