Skip to content

Commit 96541a9

Browse files
authored
feat(isthmus): mapping of positional scalar fns (#610)
1 parent 22448d1 commit 96541a9

File tree

3 files changed

+146
-2
lines changed

3 files changed

+146
-2
lines changed

isthmus/src/main/java/io/substrait/isthmus/calcite/SubstraitOperatorTable.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,13 @@ public class SubstraitOperatorTable implements SqlOperatorTable {
4747
// functions
4848
private static final SqlOperatorTable LIBRARY_OPERATOR_TABLE =
4949
SqlLibraryOperatorTableFactory.INSTANCE.getOperatorTable(
50-
EnumSet.of(SqlLibrary.HIVE, SqlLibrary.SPARK, SqlLibrary.ALL));
50+
EnumSet.of(
51+
SqlLibrary.HIVE,
52+
SqlLibrary.SPARK,
53+
SqlLibrary.ALL,
54+
SqlLibrary.BIG_QUERY,
55+
SqlLibrary.SNOWFLAKE,
56+
SqlLibrary.STANDARD));
5157

5258
private static final SqlOperatorTable STANDARD_OPERATOR_TABLE = SqlStdOperatorTable.instance();
5359

isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
public class FunctionMappings {
1313
// Static list of signature mapping between Calcite SQL operators and Substrait base function
1414
// names.
15+
1516
public static final ImmutableList<Sig> SCALAR_SIGS =
1617
ImmutableList.<Sig>builder()
1718
.add(
@@ -88,7 +89,11 @@ public class FunctionMappings {
8889
s(SqlLibraryOperators.LEAST, "least"),
8990
s(SqlLibraryOperators.GREATEST, "greatest"),
9091
s(SqlStdOperatorTable.BIT_LEFT_SHIFT, "shift_left"),
91-
s(SqlStdOperatorTable.LEFTSHIFT, "shift_left"))
92+
s(SqlStdOperatorTable.LEFTSHIFT, "shift_left"),
93+
s(SqlLibraryOperators.STARTS_WITH, "starts_with"),
94+
s(SqlLibraryOperators.ENDS_WITH, "ends_with"),
95+
s(SqlLibraryOperators.CONTAINS_SUBSTR, "contains"),
96+
s(SqlStdOperatorTable.POSITION, "strpos"))
9297
.build();
9398

9499
public static final ImmutableList<Sig> AGGREGATE_SIGS =

isthmus/src/test/java/io/substrait/isthmus/StringFunctionTest.java

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import io.substrait.plan.Plan;
66
import org.apache.calcite.sql.parser.SqlParseException;
77
import org.junit.jupiter.params.ParameterizedTest;
8+
import org.junit.jupiter.params.provider.CsvSource;
89
import org.junit.jupiter.params.provider.ValueSource;
910

1011
public final class StringFunctionTest extends PlanTestBase {
@@ -137,4 +138,136 @@ private void assertSqlRoundTrip(String sql) throws SqlParseException {
137138
Plan plan = assertProtoPlanRoundrip(sql, new SqlToSubstrait(), CREATES);
138139
assertDoesNotThrow(() -> toSql(plan), "Substrait plan to SQL");
139140
}
141+
142+
@ParameterizedTest
143+
@CsvSource({"c16, c16", "c16, vc32", "c16, vc", "vc32, vc32", "vc32, vc", "vc, vc"})
144+
void testStarts_With(String left, String right) throws Exception {
145+
146+
String query = String.format("SELECT STARTS_WITH(%s, %s) FROM strings", left, right);
147+
148+
assertSqlRoundTrip(query);
149+
}
150+
151+
@ParameterizedTest
152+
@CsvSource(
153+
value = {"'start', vc", "vc, 'end'"},
154+
quoteCharacter = '`')
155+
void testStarts_WithLiteral(String left, String right) throws Exception {
156+
String query = String.format("SELECT STARTS_WITH(%s, %s) FROM strings", left, right);
157+
assertSqlRoundTrip(query);
158+
}
159+
160+
@ParameterizedTest
161+
@CsvSource({"c16, c16", "c16, vc32", "c16, vc", "vc32, vc32", "vc32, vc", "vc, vc"})
162+
void testStartsWith(String left, String right) throws Exception {
163+
164+
String query = String.format("SELECT STARTSWITH(%s, %s) FROM strings", left, right);
165+
166+
assertSqlRoundTrip(query);
167+
}
168+
169+
@ParameterizedTest
170+
@CsvSource(
171+
value = {"'start', vc", "vc, 'end'"},
172+
quoteCharacter = '`')
173+
void testStartsWithLiteral(String left, String right) throws Exception {
174+
String query = String.format("SELECT STARTSWITH(%s, %s) FROM strings", left, right);
175+
assertSqlRoundTrip(query);
176+
}
177+
178+
@ParameterizedTest
179+
@CsvSource({"c16, c16", "c16, vc32", "c16, vc", "vc32, vc32", "vc32, vc", "vc, vc"})
180+
void testEnds_With(String left, String right) throws Exception {
181+
182+
String query = String.format("SELECT ENDS_WITH(%s, %s) FROM strings", left, right);
183+
184+
assertSqlRoundTrip(query);
185+
}
186+
187+
@ParameterizedTest
188+
@CsvSource(
189+
value = {"'start', vc", "vc, 'end'"},
190+
quoteCharacter = '`')
191+
void testEnds_WithLiteral(String left, String right) throws Exception {
192+
String query = String.format("SELECT ENDS_WITH(%s, %s) FROM strings", left, right);
193+
assertSqlRoundTrip(query);
194+
}
195+
196+
@ParameterizedTest
197+
@CsvSource({"c16, c16", "c16, vc32", "c16, vc", "vc32, vc32", "vc32, vc", "vc, vc"})
198+
void testEndsWith(String left, String right) throws Exception {
199+
200+
String query = String.format("SELECT ENDSWITH(%s, %s) FROM strings", left, right);
201+
202+
assertSqlRoundTrip(query);
203+
}
204+
205+
@ParameterizedTest
206+
@CsvSource(
207+
value = {"'start', vc", "vc, 'end'"},
208+
quoteCharacter = '`')
209+
void testEndsWithLiteral(String left, String right) throws Exception {
210+
String query = String.format("SELECT ENDSWITH(%s, %s) FROM strings", left, right);
211+
assertSqlRoundTrip(query);
212+
}
213+
214+
@ParameterizedTest
215+
@CsvSource({"c16, c16", "c16, vc32", "c16, vc", "vc32, vc32", "vc32, vc", "vc, vc"})
216+
void testContains(String left, String right) throws Exception {
217+
218+
String query = String.format("SELECT CONTAINS_SUBSTR(%s, %s) FROM strings", left, right);
219+
220+
assertSqlRoundTrip(query);
221+
}
222+
223+
@ParameterizedTest
224+
@CsvSource(
225+
value = {"'start', vc", "vc, 'end'"},
226+
quoteCharacter = '`')
227+
void testContainsWithLiteral(String left, String right) throws Exception {
228+
229+
String query = String.format("SELECT CONTAINS_SUBSTR(%s, %s) FROM strings", left, right);
230+
231+
assertSqlRoundTrip(query);
232+
}
233+
234+
@ParameterizedTest
235+
@CsvSource({"c16, c16", "c16, vc32", "c16, vc", "vc32, vc32", "vc32, vc", "vc, vc"})
236+
void testPosition(String left, String right) throws Exception {
237+
238+
String query = String.format("SELECT POSITION(%s IN %s) > 0 FROM strings", left, right);
239+
240+
assertSqlRoundTrip(query);
241+
}
242+
243+
@ParameterizedTest
244+
@CsvSource(
245+
value = {"'start', vc", "vc, 'end'"},
246+
quoteCharacter = '`')
247+
void testPositionWithLiteral(String left, String right) throws Exception {
248+
249+
String query = String.format("SELECT POSITION(%s IN %s) > 0 FROM strings", left, right);
250+
251+
assertSqlRoundTrip(query);
252+
}
253+
254+
@ParameterizedTest
255+
@CsvSource({"c16, c16", "c16, vc32", "c16, vc", "vc32, vc32", "vc32, vc", "vc, vc"})
256+
void testStrpos(String left, String right) throws Exception {
257+
258+
String query = String.format("SELECT STRPOS(%s, %s) > 0 FROM strings", left, right);
259+
260+
assertSqlRoundTrip(query);
261+
}
262+
263+
@ParameterizedTest
264+
@CsvSource(
265+
value = {"'start', vc", "vc, 'end'"},
266+
quoteCharacter = '`')
267+
void testStrposWithLiteral(String left, String right) throws Exception {
268+
269+
String query = String.format("SELECT STRPOS(%s, %s) > 0 FROM strings", left, right);
270+
271+
assertSqlRoundTrip(query);
272+
}
140273
}

0 commit comments

Comments
 (0)