Skip to content

Commit ac0b53e

Browse files
committed
feat(isthmus): add support for scalar subqueries
Signed-off-by: Niels Pardon <par@zurich.ibm.com>
1 parent dfde97a commit ac0b53e

File tree

4 files changed

+109
-20
lines changed

4 files changed

+109
-20
lines changed

isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import io.substrait.expression.Expression;
77
import io.substrait.expression.Expression.FailureBehavior;
88
import io.substrait.expression.Expression.ScalarSubquery;
9+
import io.substrait.expression.Expression.SetPredicate;
910
import io.substrait.expression.Expression.SingleOrList;
1011
import io.substrait.expression.Expression.Switch;
1112
import io.substrait.expression.FieldReference;
@@ -545,4 +546,18 @@ public RexNode visit(ScalarSubquery expr) throws RuntimeException {
545546
RelNode inputRelnode = expr.input().accept(relNodeConverter);
546547
return RexSubQuery.scalar(inputRelnode);
547548
}
549+
550+
@Override
551+
public RexNode visit(SetPredicate expr) throws RuntimeException {
552+
RelNode inputRelnode = expr.tuples().accept(relNodeConverter);
553+
switch (expr.predicateOp()) {
554+
case PREDICATE_OP_EXISTS:
555+
return RexSubQuery.exists(inputRelnode);
556+
case PREDICATE_OP_UNIQUE:
557+
return RexSubQuery.unique(inputRelnode);
558+
case PREDICATE_OP_UNSPECIFIED:
559+
default:
560+
return super.visit(expr);
561+
}
562+
}
548563
}

isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535

3636
public class PlanTestBase {
3737
protected final SimpleExtension.ExtensionCollection extensions = SimpleExtension.loadDefaults();
38-
protected final RelCreator creator = new RelCreator();
38+
protected final RelCreator creator = new RelCreator(tpchSchemaCreateStatements());
3939
protected final RelBuilder builder = creator.createRelBuilder();
4040
protected final RexBuilder rex = creator.rex();
4141
protected final RelDataTypeFactory typeFactory = creator.typeFactory();
@@ -47,11 +47,16 @@ public static String asString(String resource) throws IOException {
4747
return Resources.toString(Resources.getResource(resource), Charsets.UTF_8);
4848
}
4949

50-
public static List<String> tpchSchemaCreateStatements() throws IOException {
51-
String[] values = asString("tpch/schema.sql").split(";");
52-
return Arrays.stream(values)
53-
.filter(t -> !t.trim().isBlank())
54-
.collect(java.util.stream.Collectors.toList());
50+
public static List<String> tpchSchemaCreateStatements() {
51+
String[] values;
52+
try {
53+
values = asString("tpch/schema.sql").split(";");
54+
return Arrays.stream(values)
55+
.filter(t -> !t.trim().isBlank())
56+
.collect(java.util.stream.Collectors.toList());
57+
} catch (IOException e) {
58+
throw new RuntimeException(e);
59+
}
5560
}
5661

5762
protected Plan assertProtoPlanRoundrip(String query) throws IOException, SqlParseException {

isthmus/src/test/java/io/substrait/isthmus/RelCreator.java

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package io.substrait.isthmus;
22

33
import java.util.Arrays;
4+
import java.util.List;
45
import org.apache.calcite.config.CalciteConnectionConfig;
56
import org.apache.calcite.config.CalciteConnectionProperty;
67
import org.apache.calcite.jdbc.CalciteSchema;
@@ -25,19 +26,37 @@
2526
import org.apache.calcite.sql2rel.SqlToRelConverter;
2627
import org.apache.calcite.sql2rel.StandardConvertletTable;
2728
import org.apache.calcite.tools.RelBuilder;
29+
import org.apache.calcite.util.Pair;
2830

29-
public class RelCreator {
31+
public class RelCreator extends SqlConverterBase {
3032
static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(RelCreator.class);
3133

3234
private RelOptCluster cluster;
3335
private CalciteCatalogReader catalog;
36+
private SqlValidator validator;
3437

3538
public RelCreator() {
39+
super(SqlConverterBase.FEATURES_DEFAULT);
3640
CalciteSchema schema = CalciteSchema.createRootSchema(false);
3741
RelDataTypeFactory factory = new JavaTypeFactoryImpl(SubstraitTypeSystem.TYPE_SYSTEM);
3842
CalciteConnectionConfig config =
3943
CalciteConnectionConfig.DEFAULT.set(CalciteConnectionProperty.CASE_SENSITIVE, "false");
40-
catalog = new CalciteCatalogReader(schema, Arrays.asList(), factory, config);
44+
this.validator = new Validator(catalog, cluster.getTypeFactory(), SqlValidator.Config.DEFAULT);
45+
this.catalog = new CalciteCatalogReader(schema, Arrays.asList(), factory, config);
46+
VolcanoPlanner planner = new VolcanoPlanner(RelOptCostImpl.FACTORY, Contexts.EMPTY_CONTEXT);
47+
cluster = RelOptCluster.create(planner, new RexBuilder(factory));
48+
}
49+
50+
public RelCreator(List<String> creates) {
51+
super(SqlConverterBase.FEATURES_DEFAULT);
52+
RelDataTypeFactory factory = new JavaTypeFactoryImpl(SubstraitTypeSystem.TYPE_SYSTEM);
53+
try {
54+
Pair<SqlValidator, CalciteCatalogReader> pair = this.registerCreateTables(creates);
55+
this.validator = pair.left;
56+
this.catalog = pair.right;
57+
} catch (SqlParseException e) {
58+
throw new RuntimeException(e);
59+
}
4160
VolcanoPlanner planner = new VolcanoPlanner(RelOptCostImpl.FACTORY, Contexts.EMPTY_CONTEXT);
4261
cluster = RelOptCluster.create(planner, new RexBuilder(factory));
4362
}
@@ -51,8 +70,6 @@ public RelRoot parse(String sql) {
5170
() ->
5271
new RelMetadataQuery(
5372
new ProxyingMetadataHandlerProvider(DefaultRelMetadataProvider.INSTANCE)));
54-
SqlValidator validator =
55-
new Validator(catalog, cluster.getTypeFactory(), SqlValidator.Config.DEFAULT);
5673

5774
SqlToRelConverter.Config converterConfig =
5875
SqlToRelConverter.config().withTrimUnusedFields(true).withExpand(false);

isthmus/src/test/java/io/substrait/isthmus/SubstraitExpressionConverterTest.java

Lines changed: 62 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -53,16 +53,7 @@ public void switchExpression() {
5353

5454
@Test
5555
public void scalarSubQuery() {
56-
Rel subQueryRel =
57-
b.project(
58-
input -> List.of(b.fieldReference(input, 0)),
59-
Remap.of(List.of(3)),
60-
b.filter(
61-
input ->
62-
b.equal(
63-
b.fieldReference(input, 2),
64-
Expression.StrLiteral.builder().nullable(false).value("EUROPE").build()),
65-
commonTable));
56+
Rel subQueryRel = createSubQueryRel();
6657

6758
Expression.ScalarSubquery expr =
6859
Expression.ScalarSubquery.builder()
@@ -81,6 +72,67 @@ public void scalarSubQuery() {
8172
assertEquals(SqlKind.SCALAR_QUERY, calciteProjectExpr.get(0).getKind());
8273
}
8374

75+
@Test
76+
public void existsSetPredicate() {
77+
Rel subQueryRel = createSubQueryRel();
78+
79+
Expression.SetPredicate expr =
80+
Expression.SetPredicate.builder()
81+
.predicateOp(Expression.PredicateOp.PREDICATE_OP_EXISTS)
82+
.tuples(subQueryRel)
83+
.build();
84+
85+
Project query = b.project(input -> List.of(expr), b.emptyScan());
86+
87+
SubstraitToCalcite substraitToCalcite = new SubstraitToCalcite(extensions, typeFactory);
88+
RelNode calciteRel = substraitToCalcite.convert(query);
89+
90+
assertInstanceOf(LogicalProject.class, calciteRel);
91+
List<RexNode> calciteProjectExpr = ((LogicalProject) calciteRel).getProjects();
92+
assertEquals(1, calciteProjectExpr.size());
93+
assertEquals(SqlKind.EXISTS, calciteProjectExpr.get(0).getKind());
94+
}
95+
96+
@Test
97+
public void uniqueSetPredicate() {
98+
Rel subQueryRel = createSubQueryRel();
99+
100+
Expression.SetPredicate expr =
101+
Expression.SetPredicate.builder()
102+
.predicateOp(Expression.PredicateOp.PREDICATE_OP_UNIQUE)
103+
.tuples(subQueryRel)
104+
.build();
105+
106+
Project query = b.project(input -> List.of(expr), b.emptyScan());
107+
108+
SubstraitToCalcite substraitToCalcite = new SubstraitToCalcite(extensions, typeFactory);
109+
RelNode calciteRel = substraitToCalcite.convert(query);
110+
111+
assertInstanceOf(LogicalProject.class, calciteRel);
112+
List<RexNode> calciteProjectExpr = ((LogicalProject) calciteRel).getProjects();
113+
assertEquals(1, calciteProjectExpr.size());
114+
assertEquals(SqlKind.UNIQUE, calciteProjectExpr.get(0).getKind());
115+
}
116+
117+
/**
118+
* Creates a Substrait {@link Rel} equivalent to the following SQL query:
119+
*
120+
* <p>select a from example where c = 'EUROPE'
121+
*
122+
* @return the Substrait {@link Rel} equivalent of the above SQL query
123+
*/
124+
Rel createSubQueryRel() {
125+
return b.project(
126+
input -> List.of(b.fieldReference(input, 0)),
127+
Remap.of(List.of(3)),
128+
b.filter(
129+
input ->
130+
b.equal(
131+
b.fieldReference(input, 2),
132+
Expression.StrLiteral.builder().nullable(false).value("EUROPE").build()),
133+
commonTable));
134+
}
135+
84136
void assertTypeMatch(RelDataType actual, Type expected) {
85137
Type type = TypeConverter.DEFAULT.toSubstrait(actual);
86138
assertEquals(expected, type);

0 commit comments

Comments
 (0)