Skip to content

Commit 4428adf

Browse files
committed
fix(sqlite): handle comparison operators and sorting
1 parent d3b0872 commit 4428adf

File tree

1 file changed

+181
-7
lines changed

1 file changed

+181
-7
lines changed

internal/engine/sqlite/convert.go

Lines changed: 181 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -356,21 +356,103 @@ func (c *cc) convertComparison(n *parser.Expr_comparisonContext) ast.Node {
356356
return &ast.In{
357357
Expr: lexpr,
358358
List: rexprs,
359-
Not: false,
359+
Not: n.NOT_() != nil,
360360
Sel: nil,
361361
Location: n.GetStart().GetStart(),
362362
}
363363
}
364364

365+
operator := c.extractComparisonOperator(n)
366+
rexprIdx := 1
367+
rexpr := c.convert(n.Expr(rexprIdx))
368+
369+
// Special handling for IS NOT NULL where NOT NULL might be parsed as a unary expression
370+
if operator == "IS" && len(n.AllExpr()) > 1 {
371+
if rExpr, ok := n.Expr(1).(*parser.Expr_unaryContext); ok {
372+
// Check if this is a NOT NULL expression by looking at the text content
373+
text := rExpr.GetText()
374+
if strings.ToUpper(text) == "NOTNULL" || strings.ToUpper(text) == "NOT NULL" {
375+
operator = "IS NOT"
376+
rexpr = &ast.A_Const{Val: &ast.Null{}}
377+
}
378+
}
379+
}
380+
365381
return &ast.A_Expr{
366382
Name: &ast.List{
367383
Items: []ast.Node{
368-
&ast.String{Str: "="}, // TODO: add actual comparison
384+
&ast.String{Str: operator},
369385
},
370386
},
371387
Lexpr: lexpr,
372-
Rexpr: c.convert(n.Expr(1)),
388+
Rexpr: rexpr,
389+
}
390+
}
391+
392+
func (c *cc) extractComparisonOperator(n *parser.Expr_comparisonContext) string {
393+
switch {
394+
case n.LT2() != nil:
395+
return "<<"
396+
case n.GT2() != nil:
397+
return ">>"
398+
case n.AMP() != nil:
399+
return "&"
400+
case n.PIPE() != nil:
401+
return "|"
402+
case n.LT_EQ() != nil:
403+
return "<="
404+
case n.GT_EQ() != nil:
405+
return ">="
406+
case n.LT() != nil:
407+
return "<"
408+
case n.GT() != nil:
409+
return ">"
410+
case n.NOT_EQ1() != nil:
411+
return "!="
412+
case n.NOT_EQ2() != nil:
413+
return "<>"
414+
case n.ASSIGN() != nil || n.EQ() != nil:
415+
return "="
416+
case n.IS_() != nil:
417+
if n.NOT_() != nil {
418+
return "IS NOT"
419+
}
420+
return "IS"
421+
case n.LIKE_() != nil:
422+
if n.NOT_() != nil {
423+
return "NOT LIKE"
424+
}
425+
return "LIKE"
426+
case n.GLOB_() != nil:
427+
if n.NOT_() != nil {
428+
return "NOT GLOB"
429+
}
430+
return "GLOB"
431+
case n.MATCH_() != nil:
432+
if n.NOT_() != nil {
433+
return "NOT MATCH"
434+
}
435+
return "MATCH"
436+
case n.REGEXP_() != nil:
437+
if n.NOT_() != nil {
438+
return "NOT REGEXP"
439+
}
440+
return "REGEXP"
441+
}
442+
443+
var parts []string
444+
for _, child := range n.GetChildren() {
445+
if term, ok := child.(antlr.TerminalNode); ok {
446+
text := strings.TrimSpace(term.GetText())
447+
if text != "" {
448+
parts = append(parts, text)
449+
}
450+
}
451+
}
452+
if len(parts) > 0 {
453+
return strings.Join(parts, " ")
373454
}
455+
return "="
374456
}
375457

376458
func (c *cc) convertMultiSelect_stmtContext(n *parser.Select_stmtContext) ast.Node {
@@ -514,6 +596,11 @@ func (c *cc) convertMultiSelect_stmtContext(n *parser.Select_stmtContext) ast.No
514596
limitCount, limitOffset := c.convertLimit_stmtContext(n.Limit_stmt())
515597
selectStmt.LimitCount = limitCount
516598
selectStmt.LimitOffset = limitOffset
599+
if orderBy := n.Order_by_stmt(); orderBy != nil {
600+
if sortClause, ok := c.convertOrderby_stmtContext(orderBy).(*ast.List); ok {
601+
selectStmt.SortClause = sortClause
602+
}
603+
}
517604
selectStmt.WithClause = &ast.WithClause{Ctes: &ctes}
518605
return selectStmt
519606
}
@@ -626,10 +713,34 @@ func (c *cc) convertOrderby_stmtContext(n parser.IOrder_by_stmtContext) ast.Node
626713
if !ok {
627714
continue
628715
}
629-
list.Items = append(list.Items, &ast.CaseExpr{
630-
Xpr: c.convert(term.Expr()),
631-
Location: term.Expr().GetStart().GetStart(),
632-
})
716+
717+
expr := c.convert(term.Expr())
718+
sortBy := &ast.SortBy{
719+
Node: expr,
720+
SortbyDir: ast.SortByDirDefault,
721+
SortbyNulls: ast.SortByNullsDefault,
722+
UseOp: &ast.List{},
723+
}
724+
725+
if ascDescCtx := term.Asc_desc(); ascDescCtx != nil {
726+
if ascDesc, ok := ascDescCtx.(*parser.Asc_descContext); ok {
727+
if ascDesc.DESC_() != nil {
728+
sortBy.SortbyDir = ast.SortByDirDesc
729+
} else if ascDesc.ASC_() != nil {
730+
sortBy.SortbyDir = ast.SortByDirAsc
731+
}
732+
}
733+
}
734+
735+
if term.NULLS_() != nil {
736+
if term.FIRST_() != nil {
737+
sortBy.SortbyNulls = ast.SortByNullsFirst
738+
} else if term.LAST_() != nil {
739+
sortBy.SortbyNulls = ast.SortByNullsLast
740+
}
741+
}
742+
743+
list.Items = append(list.Items, sortBy)
633744
}
634745
return list
635746
}
@@ -1135,6 +1246,63 @@ func (c *cc) convertCase(n *parser.Expr_caseContext) ast.Node {
11351246
return e
11361247
}
11371248

1249+
func (c *cc) convertUnaryExpr(n *parser.Expr_unaryContext) ast.Node {
1250+
// Handle unary expressions like NOT NULL
1251+
children := n.GetChildren()
1252+
if len(children) >= 2 {
1253+
for i, child := range children {
1254+
if term, ok := child.(antlr.TerminalNode); ok {
1255+
if term.GetSymbol().GetTokenType() == parser.SQLiteParserNOT_ {
1256+
if i+1 < len(children) {
1257+
if nextTerm, ok := children[i+1].(antlr.TerminalNode); ok {
1258+
if nextTerm.GetSymbol().GetTokenType() == parser.SQLiteParserNULL_ {
1259+
return &ast.A_Const{Val: &ast.Null{}}
1260+
}
1261+
}
1262+
}
1263+
}
1264+
}
1265+
}
1266+
}
1267+
1268+
// For other unary expressions, try to convert the inner expression
1269+
if n.Expr() != nil {
1270+
return c.convert(n.Expr())
1271+
}
1272+
1273+
return todo("convertUnaryExpr", n)
1274+
}
1275+
1276+
func (c *cc) convertNullComparison(n *parser.Expr_null_compContext) ast.Node {
1277+
expr := c.convert(n.Expr())
1278+
1279+
var operator string
1280+
switch {
1281+
case n.ISNULL_() != nil:
1282+
operator = "IS NULL"
1283+
case n.NOTNULL_() != nil:
1284+
operator = "IS NOT NULL"
1285+
case n.NOT_() != nil && n.NULL_() != nil:
1286+
operator = "IS NOT NULL"
1287+
case n.NULL_() != nil:
1288+
operator = "IS NULL"
1289+
default:
1290+
operator = "IS NULL" // fallback
1291+
}
1292+
1293+
return &ast.A_Expr{
1294+
Name: &ast.List{
1295+
Items: []ast.Node{
1296+
&ast.String{Str: operator},
1297+
},
1298+
},
1299+
Lexpr: expr,
1300+
Rexpr: &ast.A_Const{
1301+
Val: &ast.Null{},
1302+
},
1303+
}
1304+
}
1305+
11381306
func (c *cc) convert(node node) ast.Node {
11391307
switch n := node.(type) {
11401308

@@ -1226,6 +1394,12 @@ func (c *cc) convert(node node) ast.Node {
12261394
case *parser.Expr_caseContext:
12271395
return c.convertCase(n)
12281396

1397+
case *parser.Expr_null_compContext:
1398+
return c.convertNullComparison(n)
1399+
1400+
case *parser.Expr_unaryContext:
1401+
return c.convertUnaryExpr(n)
1402+
12291403
default:
12301404
return todo("convert(case=default)", n)
12311405
}

0 commit comments

Comments
 (0)