Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

223 changes: 216 additions & 7 deletions internal/engine/sqlite/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -356,21 +356,103 @@ func (c *cc) convertComparison(n *parser.Expr_comparisonContext) ast.Node {
return &ast.In{
Expr: lexpr,
List: rexprs,
Not: false,
Not: n.NOT_() != nil,
Sel: nil,
Location: n.GetStart().GetStart(),
}
}

operator := c.extractComparisonOperator(n)
rexprIdx := 1
rexpr := c.convert(n.Expr(rexprIdx))

// Special handling for IS NOT NULL where NOT NULL might be parsed as a unary expression
if operator == "IS" && len(n.AllExpr()) > 1 {
if rExpr, ok := n.Expr(1).(*parser.Expr_unaryContext); ok {
// Check if this is a NOT NULL expression by looking at the text content
text := rExpr.GetText()
if strings.ToUpper(text) == "NOTNULL" || strings.ToUpper(text) == "NOT NULL" {
operator = "IS NOT"
rexpr = &ast.A_Const{Val: &ast.Null{}}
}
}
}

return &ast.A_Expr{
Name: &ast.List{
Items: []ast.Node{
&ast.String{Str: "="}, // TODO: add actual comparison
&ast.String{Str: operator},
},
},
Lexpr: lexpr,
Rexpr: c.convert(n.Expr(1)),
Rexpr: rexpr,
}
}

func (c *cc) extractComparisonOperator(n *parser.Expr_comparisonContext) string {
switch {
case n.LT2() != nil:
return "<<"
case n.GT2() != nil:
return ">>"
case n.AMP() != nil:
return "&"
case n.PIPE() != nil:
return "|"
case n.LT_EQ() != nil:
return "<="
case n.GT_EQ() != nil:
return ">="
case n.LT() != nil:
return "<"
case n.GT() != nil:
return ">"
case n.NOT_EQ1() != nil:
return "!="
case n.NOT_EQ2() != nil:
return "<>"
case n.ASSIGN() != nil || n.EQ() != nil:
return "="
case n.IS_() != nil:
if n.NOT_() != nil {
return "IS NOT"
}
return "IS"
case n.LIKE_() != nil:
if n.NOT_() != nil {
return "NOT LIKE"
}
return "LIKE"
case n.GLOB_() != nil:
if n.NOT_() != nil {
return "NOT GLOB"
}
return "GLOB"
case n.MATCH_() != nil:
if n.NOT_() != nil {
return "NOT MATCH"
}
return "MATCH"
case n.REGEXP_() != nil:
if n.NOT_() != nil {
return "NOT REGEXP"
}
return "REGEXP"
}

var parts []string
for _, child := range n.GetChildren() {
if term, ok := child.(antlr.TerminalNode); ok {
text := strings.TrimSpace(term.GetText())
if text != "" {
parts = append(parts, text)
}
}
}
if len(parts) > 0 {
return strings.Join(parts, " ")
}
return "="
}

func (c *cc) convertMultiSelect_stmtContext(n *parser.Select_stmtContext) ast.Node {
Expand Down Expand Up @@ -514,6 +596,11 @@ func (c *cc) convertMultiSelect_stmtContext(n *parser.Select_stmtContext) ast.No
limitCount, limitOffset := c.convertLimit_stmtContext(n.Limit_stmt())
selectStmt.LimitCount = limitCount
selectStmt.LimitOffset = limitOffset
if orderBy := n.Order_by_stmt(); orderBy != nil {
if sortClause, ok := c.convertOrderby_stmtContext(orderBy).(*ast.List); ok {
selectStmt.SortClause = sortClause
}
}
selectStmt.WithClause = &ast.WithClause{Ctes: &ctes}
return selectStmt
}
Expand Down Expand Up @@ -626,10 +713,34 @@ func (c *cc) convertOrderby_stmtContext(n parser.IOrder_by_stmtContext) ast.Node
if !ok {
continue
}
list.Items = append(list.Items, &ast.CaseExpr{
Xpr: c.convert(term.Expr()),
Location: term.Expr().GetStart().GetStart(),
})

expr := c.convert(term.Expr())
sortBy := &ast.SortBy{
Node: expr,
SortbyDir: ast.SortByDirDefault,
SortbyNulls: ast.SortByNullsDefault,
UseOp: &ast.List{},
}

if ascDescCtx := term.Asc_desc(); ascDescCtx != nil {
if ascDesc, ok := ascDescCtx.(*parser.Asc_descContext); ok {
if ascDesc.DESC_() != nil {
sortBy.SortbyDir = ast.SortByDirDesc
} else if ascDesc.ASC_() != nil {
sortBy.SortbyDir = ast.SortByDirAsc
}
}
}

if term.NULLS_() != nil {
if term.FIRST_() != nil {
sortBy.SortbyNulls = ast.SortByNullsFirst
} else if term.LAST_() != nil {
sortBy.SortbyNulls = ast.SortByNullsLast
}
}

list.Items = append(list.Items, sortBy)
}
return list
}
Expand Down Expand Up @@ -816,6 +927,34 @@ func (c *cc) convertParam(n *parser.Expr_bindContext) ast.Node {
}

func (c *cc) convertInSelectNode(n *parser.Expr_in_selectContext) ast.Node {
// Check if this is an EXISTS or NOT EXISTS expression
if n.EXISTS_() != nil {
sublink := &ast.SubLink{
SubLinkType: ast.EXISTS_SUBLINK,
Subselect: c.convert(n.Select_stmt()),
Location: n.GetStart().GetStart(),
}

notExists := n.NOT_() != nil
if !notExists && n.GetStart() != nil {
notExists = n.GetStart().GetTokenType() == parser.SQLiteParserNOT_
}

// If NOT EXISTS, wrap in a BoolExpr with NOT
if notExists {
return &ast.BoolExpr{
Boolop: ast.BoolExprTypeNot,
Args: &ast.List{
Items: []ast.Node{sublink},
},
Location: n.GetStart().GetStart(),
}
}

return sublink
}

// Handle other IN expressions (original behavior)
return c.convert(n.Select_stmt())
}

Expand Down Expand Up @@ -1135,6 +1274,70 @@ func (c *cc) convertCase(n *parser.Expr_caseContext) ast.Node {
return e
}

func (c *cc) convertUnaryExpr(n *parser.Expr_unaryContext) ast.Node {
if unary := n.Unary_operator(); unary != nil && unary.NOT_() != nil {
innerExpr := n.Expr()
if innerExpr != nil {
if strings.EqualFold(innerExpr.GetText(), "NULL") {
return &ast.A_Const{Val: &ast.Null{}}
}
if existsCtx, ok := innerExpr.(*parser.Expr_in_selectContext); ok {
inner := c.convertInSelectNode(existsCtx)
if boolNode, ok := inner.(*ast.BoolExpr); ok {
return boolNode
}
return &ast.BoolExpr{
Boolop: ast.BoolExprTypeNot,
Args: &ast.List{Items: []ast.Node{inner}},
Location: n.GetStart().GetStart(),
}
}
inner := c.convert(innerExpr)
return &ast.BoolExpr{
Boolop: ast.BoolExprTypeNot,
Args: &ast.List{Items: []ast.Node{inner}},
Location: n.GetStart().GetStart(),
}
}
}

if expr := n.Expr(); expr != nil {
return c.convert(expr)
}

return todo("convertUnaryExpr", n)
}

func (c *cc) convertNullComparison(n *parser.Expr_null_compContext) ast.Node {
expr := c.convert(n.Expr())

var operator string
switch {
case n.ISNULL_() != nil:
operator = "IS NULL"
case n.NOTNULL_() != nil:
operator = "IS NOT NULL"
case n.NOT_() != nil && n.NULL_() != nil:
operator = "IS NOT NULL"
case n.NULL_() != nil:
operator = "IS NULL"
default:
operator = "IS NULL" // fallback
}

return &ast.A_Expr{
Name: &ast.List{
Items: []ast.Node{
&ast.String{Str: operator},
},
},
Lexpr: expr,
Rexpr: &ast.A_Const{
Val: &ast.Null{},
},
}
}

func (c *cc) convert(node node) ast.Node {
switch n := node.(type) {

Expand Down Expand Up @@ -1226,6 +1429,12 @@ func (c *cc) convert(node node) ast.Node {
case *parser.Expr_caseContext:
return c.convertCase(n)

case *parser.Expr_null_compContext:
return c.convertNullComparison(n)

case *parser.Expr_unaryContext:
return c.convertUnaryExpr(n)

default:
return todo("convert(case=default)", n)
}
Expand Down
Loading
Loading