11package io .substrait .isthmus ;
22
3+ import static org .junit .jupiter .api .Assertions .assertTrue ;
4+
5+ import io .substrait .isthmus .sql .SubstraitCreateStatementParser ;
36import org .junit .jupiter .api .Test ;
47import org .junit .jupiter .params .ParameterizedTest ;
58import org .junit .jupiter .params .provider .CsvSource ;
@@ -38,13 +41,31 @@ void is_not_false() throws Exception {
3841 void is_distinct_from (String left , String right ) throws Exception {
3942 String query = String .format ("SELECT (%s IS DISTINCT FROM %s) FROM numbers" , left , right );
4043 assertSqlSubstraitRelRoundTrip (query , CREATES );
44+
45+ // Assert logical rewrite exists
46+ io .substrait .plan .Plan plan =
47+ toSubstraitPlan (
48+ query , SubstraitCreateStatementParser .processCreateStatementsToCatalog (CREATES ));
49+ String planString = plan .toString ();
50+ assertTrue (
51+ planString .contains ("and" ) && planString .contains ("or" ) && planString .contains ("equal" ),
52+ "Expected Substrait plan to contain logical rewrite for IS DISTINCT FROM" );
4153 }
4254
4355 @ ParameterizedTest
4456 @ ValueSource (strings = {"int_a" , "int_b" , "double_a" , "double_b" })
4557 void is_distinct_from_null_vs_col (String column ) throws Exception {
4658 String query = String .format ("SELECT (NULL IS DISTINCT FROM %s) FROM numbers" , column );
4759 assertSqlSubstraitRelRoundTrip (query , CREATES );
60+
61+ // Assert logical rewrite exists
62+ io .substrait .plan .Plan plan =
63+ toSubstraitPlan (
64+ query , SubstraitCreateStatementParser .processCreateStatementsToCatalog (CREATES ));
65+ String planString = plan .toString ();
66+ assertTrue (
67+ planString .contains ("is_not_null" ),
68+ "Expected Substrait plan to contain logical rewrite for NULL IS DISTINCT FROM to IS NOT NULL" );
4869 }
4970
5071 @ ParameterizedTest
@@ -71,7 +92,7 @@ void least(String args) throws Exception {
7192 })
7293 void greatest (String args ) throws Exception {
7394 String join_args = String .join (", " , args );
74- String query = String .format ("SELECT LEAST (%s) FROM numbers" , join_args );
95+ String query = String .format ("SELECT GREATEST (%s) FROM numbers" , join_args );
7596 assertSqlSubstraitRelRoundTrip (query , CREATES );
7697 }
7798}
0 commit comments