From d3b0872b73707f3a8287fc8d1213b19385e2f85f Mon Sep 17 00:00:00 2001 From: Miguel Eduardo Gil Biraud Date: Mon, 6 Oct 2025 02:48:43 +0200 Subject: [PATCH 1/3] add test --- internal/engine/sqlite/convert_test.go | 433 +++++++++++++++++++++++++ 1 file changed, 433 insertions(+) create mode 100644 internal/engine/sqlite/convert_test.go diff --git a/internal/engine/sqlite/convert_test.go b/internal/engine/sqlite/convert_test.go new file mode 100644 index 0000000000..47da2cbc1b --- /dev/null +++ b/internal/engine/sqlite/convert_test.go @@ -0,0 +1,433 @@ +package sqlite + +import ( + "strings" + "testing" + + "github.com/sqlc-dev/sqlc/internal/sql/ast" + "github.com/sqlc-dev/sqlc/internal/sql/astutils" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" +) + +func TestConvertComparison(t *testing.T) { + p := NewParser() + + tests := []struct { + name string + sql string + expected string + }{ + // Basic comparison operators + { + name: "less than", + sql: "SELECT * FROM users WHERE age < 18", + expected: "<", + }, + { + name: "greater than", + sql: "SELECT * FROM users WHERE age > 65", + expected: ">", + }, + { + name: "less than or equal", + sql: "SELECT * FROM users WHERE age <= 18", + expected: "<=", + }, + { + name: "greater than or equal", + sql: "SELECT * FROM users WHERE age >= 65", + expected: ">=", + }, + { + name: "equals", + sql: "SELECT * FROM users WHERE status = 'active'", + expected: "=", + }, + { + name: "not equals (!=)", + sql: "SELECT * FROM users WHERE status != 'inactive'", + expected: "!=", + }, + { + name: "not equals (<>)", + sql: "SELECT * FROM users WHERE status <> 'inactive'", + expected: "<>", + }, + // Bit operations + { + name: "left shift", + sql: "SELECT * FROM users WHERE flags << 2", + expected: "<<", + }, + { + name: "right shift", + sql: "SELECT * FROM users WHERE flags >> 1", + expected: ">>", + }, + { + name: "bitwise and", + sql: "SELECT * FROM users WHERE flags & 4", + expected: "&", + }, + { + name: "bitwise or", + sql: "SELECT * FROM users WHERE flags | 8", + expected: "|", + }, + // IS operators + { + name: "is null", + sql: "SELECT * FROM users WHERE email IS NULL", + expected: "IS", + }, + { + name: "is not null", + sql: "SELECT * FROM users WHERE email IS NOT NULL", + expected: "IS NOT", + }, + // LIKE operators + { + name: "like", + sql: "SELECT * FROM users WHERE name LIKE 'John%'", + expected: "LIKE", + }, + { + name: "not like", + sql: "SELECT * FROM users WHERE name NOT LIKE 'Admin%'", + expected: "NOT LIKE", + }, + // GLOB operators + { + name: "glob", + sql: "SELECT * FROM users WHERE name GLOB 'J*'", + expected: "GLOB", + }, + { + name: "not glob", + sql: "SELECT * FROM users WHERE name NOT GLOB 'A*'", + expected: "NOT GLOB", + }, + // MATCH operators + { + name: "match", + sql: "SELECT * FROM users WHERE name MATCH 'pattern'", + expected: "MATCH", + }, + { + name: "not match", + sql: "SELECT * FROM users WHERE name NOT MATCH 'pattern'", + expected: "NOT MATCH", + }, + // REGEXP operators + { + name: "regexp", + sql: "SELECT * FROM users WHERE email REGEXP '.*@example\\.com'", + expected: "REGEXP", + }, + { + name: "not regexp", + sql: "SELECT * FROM users WHERE email NOT REGEXP '.*@spam\\.com'", + expected: "NOT REGEXP", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + stmts, err := p.Parse(strings.NewReader(tc.sql)) + if err != nil { + t.Fatalf("Failed to parse SQL: %v", err) + } + + if len(stmts) != 1 { + t.Fatalf("Expected 1 statement, got %d", len(stmts)) + } + + stmt := stmts[0].Raw.Stmt + selectStmt, ok := stmt.(*ast.SelectStmt) + if !ok { + t.Fatalf("Expected SelectStmt, got %T", stmt) + } + + // Find the comparison expression in the WHERE clause + if selectStmt.WhereClause == nil { + t.Fatal("Expected WHERE clause") + } + + var foundOperator string + astutils.Walk(astutils.VisitorFunc(func(node ast.Node) { + if aExpr, ok := node.(*ast.A_Expr); ok { + if aExpr.Name != nil && len(aExpr.Name.Items) > 0 { + if str, ok := aExpr.Name.Items[0].(*ast.String); ok { + foundOperator = str.Str + } + } + } + }), selectStmt.WhereClause) + + if foundOperator != tc.expected { + t.Errorf("Expected operator %q, got %q", tc.expected, foundOperator) + } + }) + } +} + +func TestConvertInOperation(t *testing.T) { + p := NewParser() + + tests := []struct { + name string + sql string + expectNot bool + }{ + { + name: "in operation", + sql: "SELECT * FROM users WHERE status IN ('active', 'pending')", + expectNot: false, + }, + { + name: "not in operation", + sql: "SELECT * FROM users WHERE status NOT IN ('inactive', 'deleted')", + expectNot: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + stmts, err := p.Parse(strings.NewReader(tc.sql)) + if err != nil { + t.Fatalf("Failed to parse SQL: %v", err) + } + + if len(stmts) != 1 { + t.Fatalf("Expected 1 statement, got %d", len(stmts)) + } + + stmt := stmts[0].Raw.Stmt + selectStmt, ok := stmt.(*ast.SelectStmt) + if !ok { + t.Fatalf("Expected SelectStmt, got %T", stmt) + } + + // Find the IN expression in the WHERE clause + if selectStmt.WhereClause == nil { + t.Fatal("Expected WHERE clause") + } + + var foundIn *ast.In + astutils.Walk(astutils.VisitorFunc(func(node ast.Node) { + if inExpr, ok := node.(*ast.In); ok { + foundIn = inExpr + } + }), selectStmt.WhereClause) + + if foundIn == nil { + t.Fatal("Expected IN expression") + } + + if foundIn.Not != tc.expectNot { + t.Errorf("Expected NOT=%v, got NOT=%v", tc.expectNot, foundIn.Not) + } + }) + } +} + +func TestConvertOrderBy(t *testing.T) { + p := NewParser() + + tests := []struct { + name string + sql string + expectedDirs []ast.SortByDir + expectedNulls []ast.SortByNulls + }{ + { + name: "order by default", + sql: "SELECT * FROM users ORDER BY name", + expectedDirs: []ast.SortByDir{ast.SortByDirDefault}, + expectedNulls: []ast.SortByNulls{ast.SortByNullsDefault}, + }, + { + name: "order by asc", + sql: "SELECT * FROM users ORDER BY name ASC", + expectedDirs: []ast.SortByDir{ast.SortByDirAsc}, + expectedNulls: []ast.SortByNulls{ast.SortByNullsDefault}, + }, + { + name: "order by desc", + sql: "SELECT * FROM users ORDER BY age DESC", + expectedDirs: []ast.SortByDir{ast.SortByDirDesc}, + expectedNulls: []ast.SortByNulls{ast.SortByNullsDefault}, + }, + { + name: "order by nulls first", + sql: "SELECT * FROM users ORDER BY email NULLS FIRST", + expectedDirs: []ast.SortByDir{ast.SortByDirDefault}, + expectedNulls: []ast.SortByNulls{ast.SortByNullsFirst}, + }, + { + name: "order by nulls last", + sql: "SELECT * FROM users ORDER BY email NULLS LAST", + expectedDirs: []ast.SortByDir{ast.SortByDirDefault}, + expectedNulls: []ast.SortByNulls{ast.SortByNullsLast}, + }, + { + name: "order by desc nulls first", + sql: "SELECT * FROM users ORDER BY score DESC NULLS FIRST", + expectedDirs: []ast.SortByDir{ast.SortByDirDesc}, + expectedNulls: []ast.SortByNulls{ast.SortByNullsFirst}, + }, + { + name: "order by multiple columns", + sql: "SELECT * FROM users ORDER BY name ASC, age DESC", + expectedDirs: []ast.SortByDir{ast.SortByDirAsc, ast.SortByDirDesc}, + expectedNulls: []ast.SortByNulls{ast.SortByNullsDefault, ast.SortByNullsDefault}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + stmts, err := p.Parse(strings.NewReader(tc.sql)) + if err != nil { + t.Fatalf("Failed to parse SQL: %v", err) + } + + if len(stmts) != 1 { + t.Fatalf("Expected 1 statement, got %d", len(stmts)) + } + + stmt := stmts[0].Raw.Stmt + selectStmt, ok := stmt.(*ast.SelectStmt) + if !ok { + t.Fatalf("Expected SelectStmt, got %T", stmt) + } + + // Check if SortClause is properly set + if selectStmt.SortClause == nil { + t.Fatal("Expected SortClause to be set") + } + + if len(selectStmt.SortClause.Items) != len(tc.expectedDirs) { + t.Fatalf("Expected %d sort items, got %d", len(tc.expectedDirs), len(selectStmt.SortClause.Items)) + } + + // Check each sort item + for i, item := range selectStmt.SortClause.Items { + sortBy, ok := item.(*ast.SortBy) + if !ok { + t.Fatalf("Expected SortBy at index %d, got %T", i, item) + } + + if sortBy.SortbyDir != tc.expectedDirs[i] { + t.Errorf("Expected SortbyDir %v at index %d, got %v", tc.expectedDirs[i], i, sortBy.SortbyDir) + } + + if sortBy.SortbyNulls != tc.expectedNulls[i] { + t.Errorf("Expected SortbyNulls %v at index %d, got %v", tc.expectedNulls[i], i, sortBy.SortbyNulls) + } + } + }) + } +} + +func TestConvertComplexQueries(t *testing.T) { + p := NewParser() + + tests := []struct { + name string + sql string + }{ + { + name: "complex where with multiple operators", + sql: "SELECT * FROM users WHERE age >= 18 AND status = 'active' AND email NOT LIKE '%@spam.com' ORDER BY name ASC, age DESC", + }, + { + name: "query with IN and ORDER BY", + sql: "SELECT * FROM products WHERE category_id IN (1, 2, 3) AND price > 100 ORDER BY price DESC NULLS LAST", + }, + { + name: "query with IS NOT and complex ordering", + sql: "SELECT * FROM orders WHERE processed_at IS NOT NULL AND total >= 50 ORDER BY created_at DESC, total ASC NULLS FIRST", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + stmts, err := p.Parse(strings.NewReader(tc.sql)) + if err != nil { + t.Fatalf("Failed to parse SQL: %v", err) + } + + if len(stmts) != 1 { + t.Fatalf("Expected 1 statement, got %d", len(stmts)) + } + + stmt := stmts[0].Raw.Stmt + selectStmt, ok := stmt.(*ast.SelectStmt) + if !ok { + t.Fatalf("Expected SelectStmt, got %T", stmt) + } + + // Basic checks to ensure parsing didn't fail + if selectStmt.WhereClause == nil { + t.Error("Expected WHERE clause") + } + + if selectStmt.SortClause == nil { + t.Error("Expected ORDER BY clause") + } + + // Verify no TODO nodes were created (which would indicate parsing failures) + var foundTodo bool + astutils.Walk(astutils.VisitorFunc(func(node ast.Node) { + if _, ok := node.(*ast.TODO); ok { + foundTodo = true + } + }), selectStmt) + + if foundTodo { + t.Error("Found TODO node, indicating incomplete parsing") + } + }) + } +} + +// Helper function to extract all A_Expr operators from a WHERE clause +func extractOperators(whereClause ast.Node) []string { + var operators []string + astutils.Walk(astutils.VisitorFunc(func(node ast.Node) { + if aExpr, ok := node.(*ast.A_Expr); ok { + if aExpr.Name != nil && len(aExpr.Name.Items) > 0 { + if str, ok := aExpr.Name.Items[0].(*ast.String); ok { + operators = append(operators, str.Str) + } + } + } + }), whereClause) + return operators +} + +func TestExtractComparisonOperator(t *testing.T) { + // Test that our helper can extract multiple operators from complex queries + p := NewParser() + + sql := "SELECT * FROM users WHERE age >= 18 AND status != 'inactive' AND email LIKE '%@company.com'" + stmts, err := p.Parse(strings.NewReader(sql)) + if err != nil { + t.Fatalf("Failed to parse SQL: %v", err) + } + + stmt := stmts[0].Raw.Stmt + selectStmt, ok := stmt.(*ast.SelectStmt) + if !ok { + t.Fatalf("Expected SelectStmt, got %T", stmt) + } + + operators := extractOperators(selectStmt.WhereClause) + expectedOperators := []string{">=", "!=", "LIKE"} + + if diff := cmp.Diff(expectedOperators, operators, cmpopts.EquateEmpty()); diff != "" { + t.Errorf("operators mismatch:\n%s", diff) + } +} From 4428adf2918d338c1066fce3fd3819257de89150 Mon Sep 17 00:00:00 2001 From: Miguel Eduardo Gil Biraud Date: Mon, 6 Oct 2025 02:49:50 +0200 Subject: [PATCH 2/3] fix(sqlite): handle comparison operators and sorting --- internal/engine/sqlite/convert.go | 188 ++++++++++++++++++++++++++++-- 1 file changed, 181 insertions(+), 7 deletions(-) diff --git a/internal/engine/sqlite/convert.go b/internal/engine/sqlite/convert.go index 658a9d7f33..3c8ab6a302 100644 --- a/internal/engine/sqlite/convert.go +++ b/internal/engine/sqlite/convert.go @@ -356,21 +356,103 @@ func (c *cc) convertComparison(n *parser.Expr_comparisonContext) ast.Node { return &ast.In{ Expr: lexpr, List: rexprs, - Not: false, + Not: n.NOT_() != nil, Sel: nil, Location: n.GetStart().GetStart(), } } + operator := c.extractComparisonOperator(n) + rexprIdx := 1 + rexpr := c.convert(n.Expr(rexprIdx)) + + // Special handling for IS NOT NULL where NOT NULL might be parsed as a unary expression + if operator == "IS" && len(n.AllExpr()) > 1 { + if rExpr, ok := n.Expr(1).(*parser.Expr_unaryContext); ok { + // Check if this is a NOT NULL expression by looking at the text content + text := rExpr.GetText() + if strings.ToUpper(text) == "NOTNULL" || strings.ToUpper(text) == "NOT NULL" { + operator = "IS NOT" + rexpr = &ast.A_Const{Val: &ast.Null{}} + } + } + } + return &ast.A_Expr{ Name: &ast.List{ Items: []ast.Node{ - &ast.String{Str: "="}, // TODO: add actual comparison + &ast.String{Str: operator}, }, }, Lexpr: lexpr, - Rexpr: c.convert(n.Expr(1)), + Rexpr: rexpr, + } +} + +func (c *cc) extractComparisonOperator(n *parser.Expr_comparisonContext) string { + switch { + case n.LT2() != nil: + return "<<" + case n.GT2() != nil: + return ">>" + case n.AMP() != nil: + return "&" + case n.PIPE() != nil: + return "|" + case n.LT_EQ() != nil: + return "<=" + case n.GT_EQ() != nil: + return ">=" + case n.LT() != nil: + return "<" + case n.GT() != nil: + return ">" + case n.NOT_EQ1() != nil: + return "!=" + case n.NOT_EQ2() != nil: + return "<>" + case n.ASSIGN() != nil || n.EQ() != nil: + return "=" + case n.IS_() != nil: + if n.NOT_() != nil { + return "IS NOT" + } + return "IS" + case n.LIKE_() != nil: + if n.NOT_() != nil { + return "NOT LIKE" + } + return "LIKE" + case n.GLOB_() != nil: + if n.NOT_() != nil { + return "NOT GLOB" + } + return "GLOB" + case n.MATCH_() != nil: + if n.NOT_() != nil { + return "NOT MATCH" + } + return "MATCH" + case n.REGEXP_() != nil: + if n.NOT_() != nil { + return "NOT REGEXP" + } + return "REGEXP" + } + + var parts []string + for _, child := range n.GetChildren() { + if term, ok := child.(antlr.TerminalNode); ok { + text := strings.TrimSpace(term.GetText()) + if text != "" { + parts = append(parts, text) + } + } + } + if len(parts) > 0 { + return strings.Join(parts, " ") } + return "=" } func (c *cc) convertMultiSelect_stmtContext(n *parser.Select_stmtContext) ast.Node { @@ -514,6 +596,11 @@ func (c *cc) convertMultiSelect_stmtContext(n *parser.Select_stmtContext) ast.No limitCount, limitOffset := c.convertLimit_stmtContext(n.Limit_stmt()) selectStmt.LimitCount = limitCount selectStmt.LimitOffset = limitOffset + if orderBy := n.Order_by_stmt(); orderBy != nil { + if sortClause, ok := c.convertOrderby_stmtContext(orderBy).(*ast.List); ok { + selectStmt.SortClause = sortClause + } + } selectStmt.WithClause = &ast.WithClause{Ctes: &ctes} return selectStmt } @@ -626,10 +713,34 @@ func (c *cc) convertOrderby_stmtContext(n parser.IOrder_by_stmtContext) ast.Node if !ok { continue } - list.Items = append(list.Items, &ast.CaseExpr{ - Xpr: c.convert(term.Expr()), - Location: term.Expr().GetStart().GetStart(), - }) + + expr := c.convert(term.Expr()) + sortBy := &ast.SortBy{ + Node: expr, + SortbyDir: ast.SortByDirDefault, + SortbyNulls: ast.SortByNullsDefault, + UseOp: &ast.List{}, + } + + if ascDescCtx := term.Asc_desc(); ascDescCtx != nil { + if ascDesc, ok := ascDescCtx.(*parser.Asc_descContext); ok { + if ascDesc.DESC_() != nil { + sortBy.SortbyDir = ast.SortByDirDesc + } else if ascDesc.ASC_() != nil { + sortBy.SortbyDir = ast.SortByDirAsc + } + } + } + + if term.NULLS_() != nil { + if term.FIRST_() != nil { + sortBy.SortbyNulls = ast.SortByNullsFirst + } else if term.LAST_() != nil { + sortBy.SortbyNulls = ast.SortByNullsLast + } + } + + list.Items = append(list.Items, sortBy) } return list } @@ -1135,6 +1246,63 @@ func (c *cc) convertCase(n *parser.Expr_caseContext) ast.Node { return e } +func (c *cc) convertUnaryExpr(n *parser.Expr_unaryContext) ast.Node { + // Handle unary expressions like NOT NULL + children := n.GetChildren() + if len(children) >= 2 { + for i, child := range children { + if term, ok := child.(antlr.TerminalNode); ok { + if term.GetSymbol().GetTokenType() == parser.SQLiteParserNOT_ { + if i+1 < len(children) { + if nextTerm, ok := children[i+1].(antlr.TerminalNode); ok { + if nextTerm.GetSymbol().GetTokenType() == parser.SQLiteParserNULL_ { + return &ast.A_Const{Val: &ast.Null{}} + } + } + } + } + } + } + } + + // For other unary expressions, try to convert the inner expression + if n.Expr() != nil { + return c.convert(n.Expr()) + } + + return todo("convertUnaryExpr", n) +} + +func (c *cc) convertNullComparison(n *parser.Expr_null_compContext) ast.Node { + expr := c.convert(n.Expr()) + + var operator string + switch { + case n.ISNULL_() != nil: + operator = "IS NULL" + case n.NOTNULL_() != nil: + operator = "IS NOT NULL" + case n.NOT_() != nil && n.NULL_() != nil: + operator = "IS NOT NULL" + case n.NULL_() != nil: + operator = "IS NULL" + default: + operator = "IS NULL" // fallback + } + + return &ast.A_Expr{ + Name: &ast.List{ + Items: []ast.Node{ + &ast.String{Str: operator}, + }, + }, + Lexpr: expr, + Rexpr: &ast.A_Const{ + Val: &ast.Null{}, + }, + } +} + func (c *cc) convert(node node) ast.Node { switch n := node.(type) { @@ -1226,6 +1394,12 @@ func (c *cc) convert(node node) ast.Node { case *parser.Expr_caseContext: return c.convertCase(n) + case *parser.Expr_null_compContext: + return c.convertNullComparison(n) + + case *parser.Expr_unaryContext: + return c.convertUnaryExpr(n) + default: return todo("convert(case=default)", n) } From aecdc35fa2d9b953d0938b5d61e902919de17f49 Mon Sep 17 00:00:00 2001 From: Miguel Eduardo Gil Biraud Date: Mon, 6 Oct 2025 04:09:54 +0200 Subject: [PATCH 3/3] fix(sqlite): update end2end tests and fix not exists handling align with generation for postgres --- .../select_exists/sqlite/go/query.sql.go | 8 +-- .../select_not_exists/sqlite/go/query.sql.go | 10 +-- internal/engine/sqlite/convert.go | 67 ++++++++++++++----- 3 files changed, 60 insertions(+), 25 deletions(-) diff --git a/internal/endtoend/testdata/select_exists/sqlite/go/query.sql.go b/internal/endtoend/testdata/select_exists/sqlite/go/query.sql.go index e22e5b6f33..b30fa7d95a 100644 --- a/internal/endtoend/testdata/select_exists/sqlite/go/query.sql.go +++ b/internal/endtoend/testdata/select_exists/sqlite/go/query.sql.go @@ -21,9 +21,9 @@ SELECT ) ` -func (q *Queries) BarExists(ctx context.Context, id int64) (int64, error) { +func (q *Queries) BarExists(ctx context.Context, id int64) (bool, error) { row := q.db.QueryRowContext(ctx, barExists, id) - var column_1 int64 - err := row.Scan(&column_1) - return column_1, err + var exists bool + err := row.Scan(&exists) + return exists, err } diff --git a/internal/endtoend/testdata/select_not_exists/sqlite/go/query.sql.go b/internal/endtoend/testdata/select_not_exists/sqlite/go/query.sql.go index ee1b8e548b..fb2af2fd04 100644 --- a/internal/endtoend/testdata/select_not_exists/sqlite/go/query.sql.go +++ b/internal/endtoend/testdata/select_not_exists/sqlite/go/query.sql.go @@ -21,9 +21,9 @@ SELECT ) ` -func (q *Queries) BarNotExists(ctx context.Context) (interface{}, error) { - row := q.db.QueryRowContext(ctx, barNotExists) - var column_1 interface{} - err := row.Scan(&column_1) - return column_1, err +func (q *Queries) BarNotExists(ctx context.Context, id int64) (bool, error) { + row := q.db.QueryRowContext(ctx, barNotExists, id) + var not_exists bool + err := row.Scan(¬_exists) + return not_exists, err } diff --git a/internal/engine/sqlite/convert.go b/internal/engine/sqlite/convert.go index 3c8ab6a302..c389c2d356 100644 --- a/internal/engine/sqlite/convert.go +++ b/internal/engine/sqlite/convert.go @@ -927,6 +927,34 @@ func (c *cc) convertParam(n *parser.Expr_bindContext) ast.Node { } func (c *cc) convertInSelectNode(n *parser.Expr_in_selectContext) ast.Node { + // Check if this is an EXISTS or NOT EXISTS expression + if n.EXISTS_() != nil { + sublink := &ast.SubLink{ + SubLinkType: ast.EXISTS_SUBLINK, + Subselect: c.convert(n.Select_stmt()), + Location: n.GetStart().GetStart(), + } + + notExists := n.NOT_() != nil + if !notExists && n.GetStart() != nil { + notExists = n.GetStart().GetTokenType() == parser.SQLiteParserNOT_ + } + + // If NOT EXISTS, wrap in a BoolExpr with NOT + if notExists { + return &ast.BoolExpr{ + Boolop: ast.BoolExprTypeNot, + Args: &ast.List{ + Items: []ast.Node{sublink}, + }, + Location: n.GetStart().GetStart(), + } + } + + return sublink + } + + // Handle other IN expressions (original behavior) return c.convert(n.Select_stmt()) } @@ -1247,27 +1275,34 @@ func (c *cc) convertCase(n *parser.Expr_caseContext) ast.Node { } func (c *cc) convertUnaryExpr(n *parser.Expr_unaryContext) ast.Node { - // Handle unary expressions like NOT NULL - children := n.GetChildren() - if len(children) >= 2 { - for i, child := range children { - if term, ok := child.(antlr.TerminalNode); ok { - if term.GetSymbol().GetTokenType() == parser.SQLiteParserNOT_ { - if i+1 < len(children) { - if nextTerm, ok := children[i+1].(antlr.TerminalNode); ok { - if nextTerm.GetSymbol().GetTokenType() == parser.SQLiteParserNULL_ { - return &ast.A_Const{Val: &ast.Null{}} - } - } - } + if unary := n.Unary_operator(); unary != nil && unary.NOT_() != nil { + innerExpr := n.Expr() + if innerExpr != nil { + if strings.EqualFold(innerExpr.GetText(), "NULL") { + return &ast.A_Const{Val: &ast.Null{}} + } + if existsCtx, ok := innerExpr.(*parser.Expr_in_selectContext); ok { + inner := c.convertInSelectNode(existsCtx) + if boolNode, ok := inner.(*ast.BoolExpr); ok { + return boolNode + } + return &ast.BoolExpr{ + Boolop: ast.BoolExprTypeNot, + Args: &ast.List{Items: []ast.Node{inner}}, + Location: n.GetStart().GetStart(), } } + inner := c.convert(innerExpr) + return &ast.BoolExpr{ + Boolop: ast.BoolExprTypeNot, + Args: &ast.List{Items: []ast.Node{inner}}, + Location: n.GetStart().GetStart(), + } } } - // For other unary expressions, try to convert the inner expression - if n.Expr() != nil { - return c.convert(n.Expr()) + if expr := n.Expr(); expr != nil { + return c.convert(expr) } return todo("convertUnaryExpr", n)