diff --git a/internal/engine/sqlite/catalog_test.go b/internal/engine/sqlite/catalog_test.go index bf6dcd8316..a4f2861838 100644 --- a/internal/engine/sqlite/catalog_test.go +++ b/internal/engine/sqlite/catalog_test.go @@ -82,6 +82,30 @@ func TestUpdate(t *testing.T) { }, }, }, + { + ` + CREATE TABLE foo (bar text); + ALTER TABLE foo ADD COLUMN baz; + `, + &catalog.Schema{ + Name: "main", + Tables: []*catalog.Table{ + { + Rel: &ast.TableName{Name: "foo"}, + Columns: []*catalog.Column{ + { + Name: "bar", + Type: ast.TypeName{Name: "text"}, + }, + { + Name: "baz", + Type: ast.TypeName{Name: "any"}, + }, + }, + }, + }, + }, + }, { ` CREATE TABLE foo (bar text); diff --git a/internal/engine/sqlite/convert.go b/internal/engine/sqlite/convert.go index 02d80bc48c..d97977bbb0 100644 --- a/internal/engine/sqlite/convert.go +++ b/internal/engine/sqlite/convert.go @@ -35,6 +35,13 @@ func identifier(id string) string { return strings.ToLower(id) } +func getTypeName(t parser.IType_nameContext) string { + if t == nil { + return "any" + } + return t.GetText() +} + func NewIdentifier(t string) *ast.String { return &ast.String{Str: identifier(t)} } @@ -72,10 +79,8 @@ func (c *cc) convertAlter_table_stmtContext(n *parser.Alter_table_stmtContext) a Name: &name, Subtype: ast.AT_AddColumn, Def: &ast.ColumnDef{ - Colname: name, - TypeName: &ast.TypeName{ - Name: def.Type_name().GetText(), - }, + Colname: name, + TypeName: &ast.TypeName{Name: getTypeName(def.Type_name())}, IsNotNull: hasNotNullConstraint(def.AllColumn_constraint()), }, }) @@ -113,14 +118,10 @@ func (c *cc) convertCreate_table_stmtContext(n *parser.Create_table_stmtContext) } for _, idef := range n.AllColumn_def() { if def, ok := idef.(*parser.Column_defContext); ok { - typeName := "any" - if def.Type_name() != nil { - typeName = def.Type_name().GetText() - } stmt.Cols = append(stmt.Cols, &ast.ColumnDef{ Colname: identifier(def.Column_name().GetText()), IsNotNull: hasNotNullConstraint(def.AllColumn_constraint()), - TypeName: &ast.TypeName{Name: typeName}, + TypeName: &ast.TypeName{Name: getTypeName(def.Type_name())}, }) } }