Skip to content

Commit d5f059f

Browse files
fix(isthmus): concat of different string types
The type matching used when mapping Calcite to Substrait expressions required parameters and return type of the concat function to be either all varchar or all string. Any mixing of types caused a failure. Signed-off-by: Mark S. Lewis <[email protected]>
1 parent fd74922 commit d5f059f

File tree

8 files changed

+88
-126
lines changed

8 files changed

+88
-126
lines changed

isthmus/src/main/java/io/substrait/isthmus/SubstraitTypeSystem.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@ public int getMaxNumericPrecision() {
3939
return 38;
4040
}
4141

42+
@Override
43+
public boolean shouldConvertRaggedUnionTypesToVarying() {
44+
return true;
45+
}
46+
4247
public static RelDataTypeFactory createTypeFactory() {
4348
return new JavaTypeFactoryImpl(TYPE_SYSTEM);
4449
}

isthmus/src/main/java/io/substrait/isthmus/expression/AggregateFunctionConverter.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ public AggregateFunctionConverter(
5252
protected AggregateFunctionInvocation generateBinding(
5353
WrappedAggregateCall call,
5454
SimpleExtension.AggregateFunctionVariant function,
55-
List<FunctionArg> arguments,
55+
List<? extends FunctionArg> arguments,
5656
Type outputType) {
5757
AggregateCall agg = call.getUnderlying();
5858

isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java

Lines changed: 51 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import java.util.List;
2626
import java.util.Locale;
2727
import java.util.Map;
28+
import java.util.Objects;
2829
import java.util.Optional;
2930
import java.util.Set;
3031
import java.util.function.Function;
@@ -148,7 +149,6 @@ protected class FunctionFinder {
148149
private final SqlOperator operator;
149150
private final List<F> functions;
150151
private final Map<String, F> directMap;
151-
private final SignatureMatcher<F> matcher;
152152
private final Optional<SingularArgumentMatcher<F>> singularInputType;
153153
private final Util.IntRange argRange;
154154

@@ -160,7 +160,6 @@ public FunctionFinder(String name, SqlOperator operator, List<F> functions) {
160160
Util.IntRange.of(
161161
functions.stream().mapToInt(t -> t.getRange().getStartInclusive()).min().getAsInt(),
162162
functions.stream().mapToInt(t -> t.getRange().getEndExclusive()).max().getAsInt());
163-
this.matcher = getSignatureMatcher(operator, functions);
164163
this.singularInputType = getSingularInputType(functions);
165164
var directMap = ImmutableMap.<String, F>builder();
166165
for (var func : functions) {
@@ -177,21 +176,18 @@ public boolean allowedArgCount(int count) {
177176
return argRange.within(count);
178177
}
179178

180-
private static <F extends SimpleExtension.Function> SignatureMatcher<F> getSignatureMatcher(
181-
SqlOperator operator, List<F> functions) {
182-
return (inputTypes, outputType) -> {
183-
for (F function : functions) {
184-
List<SimpleExtension.Argument> args = function.requiredArguments();
185-
// Make sure that arguments & return are within bounds and match the types
186-
if (function.returnType() instanceof ParameterizedType
187-
&& isMatch(outputType, (ParameterizedType) function.returnType())
188-
&& inputTypesSatisfyDefinedArguments(inputTypes, args)) {
189-
return Optional.of(function);
190-
}
179+
private Optional<F> signatureMatch(List<Type> inputTypes, Type outputType) {
180+
for (F function : functions) {
181+
List<SimpleExtension.Argument> args = function.requiredArguments();
182+
// Make sure that arguments & return are within bounds and match the types
183+
if (function.returnType() instanceof ParameterizedType
184+
&& isMatch(outputType, (ParameterizedType) function.returnType())
185+
&& inputTypesMatchDefinedArguments(inputTypes, args)) {
186+
return Optional.of(function);
191187
}
188+
}
192189

193-
return Optional.empty();
194-
};
190+
return Optional.empty();
195191
}
196192

197193
/**
@@ -207,7 +203,7 @@ && inputTypesSatisfyDefinedArguments(inputTypes, args)) {
207203
* @param args expected arguments as defined in a {@link SimpleExtension.Function}
208204
* @return true if the {@code inputTypes} satisfy the {@code args}, false otherwise
209205
*/
210-
private static boolean inputTypesSatisfyDefinedArguments(
206+
private static boolean inputTypesMatchDefinedArguments(
211207
List<Type> inputTypes, List<SimpleExtension.Argument> args) {
212208

213209
Map<String, Set<Type>> wildcardToType = new HashMap<>();
@@ -317,7 +313,7 @@ private Stream<String> matchKeys(List<RexNode> rexOperands, List<String> opTypes
317313

318314
assert (rexOperands.size() == opTypes.size());
319315

320-
if (rexOperands.size() == 0) {
316+
if (rexOperands.isEmpty()) {
321317
return Stream.of("");
322318
} else {
323319
List<List<String>> argTypeLists =
@@ -346,23 +342,20 @@ public Optional<T> attemptMatch(C call, Function<RexNode, Expression> topLevelCo
346342
* Once a FunctionVariant is resolved we can map the String Literal
347343
* to a EnumArg.
348344
*/
349-
var operands =
350-
call.getOperands().map(topLevelConverter).collect(java.util.stream.Collectors.toList());
351-
var opTypes =
352-
operands.stream().map(Expression::getType).collect(java.util.stream.Collectors.toList());
345+
var operands = call.getOperands().map(topLevelConverter).collect(Collectors.toList());
346+
var opTypes = operands.stream().map(Expression::getType).collect(Collectors.toList());
353347

354348
var outputType = typeConverter.toSubstrait(call.getType());
355349

356350
// try to do a direct match
357351
var typeStrings =
358352
opTypes.stream().map(t -> t.accept(ToTypeString.INSTANCE)).collect(Collectors.toList());
359-
var possibleKeys =
360-
matchKeys(call.getOperands().collect(java.util.stream.Collectors.toList()), typeStrings);
353+
var possibleKeys = matchKeys(call.getOperands().collect(Collectors.toList()), typeStrings);
361354

362355
var directMatchKey =
363356
possibleKeys
364357
.map(argList -> name + ":" + argList)
365-
.filter(k -> directMap.containsKey(k))
358+
.filter(directMap::containsKey)
366359
.findFirst();
367360

368361
if (directMatchKey.isPresent()) {
@@ -375,13 +368,13 @@ public Optional<T> attemptMatch(C call, Function<RexNode, Expression> topLevelCo
375368
operands.stream(),
376369
(r, o) -> {
377370
if (EnumConverter.isEnumValue(r)) {
378-
return EnumConverter.fromRex(variant, (RexLiteral) r).orElseGet(() -> null);
371+
return EnumConverter.fromRex(variant, (RexLiteral) r).orElse(null);
379372
} else {
380373
return o;
381374
}
382375
})
383-
.collect(java.util.stream.Collectors.toList());
384-
var allArgsMapped = funcArgs.stream().filter(e -> e == null).findFirst().isEmpty();
376+
.collect(Collectors.toList());
377+
var allArgsMapped = funcArgs.stream().filter(Objects::isNull).findFirst().isEmpty();
385378
if (allArgsMapped) {
386379
return Optional.of(generateBinding(call, variant, funcArgs, outputType));
387380
} else {
@@ -411,53 +404,35 @@ private Optional<T> matchByLeastRestrictive(
411404
return Optional.empty();
412405
}
413406
Type type = typeConverter.toSubstrait(leastRestrictive);
414-
var out = singularInputType.get().tryMatch(type, outputType);
415-
416-
if (out.isPresent()) {
417-
var declaration = out.get();
418-
var coercedArgs = coerceArguments(operands, type);
419-
declaration.validateOutputType(coercedArgs, outputType);
420-
return Optional.of(
421-
generateBinding(
422-
call,
423-
out.get(),
424-
coercedArgs.stream().map(FunctionArg.class::cast).collect(Collectors.toList()),
425-
outputType));
426-
}
427-
return Optional.empty();
407+
var out = singularInputType.orElseThrow().tryMatch(type, outputType);
408+
409+
return out.map(
410+
declaration -> {
411+
var coercedArgs = coerceArguments(operands, type);
412+
declaration.validateOutputType(coercedArgs, outputType);
413+
return generateBinding(call, out.get(), coercedArgs, outputType);
414+
});
428415
}
429416

430-
private Optional<T> matchCoerced(C call, Type outputType, List<Expression> operands) {
431-
417+
private Optional<T> matchCoerced(C call, Type outputType, List<Expression> expressions) {
432418
// Convert the operands to the proper Substrait type
433-
List<Type> allTypes =
419+
List<Type> operandTypes =
434420
call.getOperands()
435421
.map(RexNode::getType)
436422
.map(typeConverter::toSubstrait)
437423
.collect(Collectors.toList());
438424

439-
// See if all the input types match the function
440-
Optional<F> matchFunction = this.matcher.tryMatch(allTypes, outputType);
441-
if (matchFunction.isPresent()) {
442-
List<Expression> coerced =
443-
Streams.zip(
444-
operands.stream(),
445-
call.getOperands(),
446-
(a, b) -> {
447-
Type type = typeConverter.toSubstrait(b.getType());
448-
return coerceArgument(a, type);
449-
})
450-
.collect(Collectors.toList());
451-
452-
return Optional.of(
453-
generateBinding(
454-
call,
455-
matchFunction.get(),
456-
coerced.stream().map(FunctionArg.class::cast).collect(Collectors.toList()),
457-
outputType));
425+
// See if all the input types can be made to match the function
426+
Optional<F> matchFunction = signatureMatch(operandTypes, outputType);
427+
if (matchFunction.isEmpty()) {
428+
return Optional.empty();
458429
}
459430

460-
return Optional.empty();
431+
var coercedArgs =
432+
Streams.zip(
433+
expressions.stream(), operandTypes.stream(), FunctionConverter::coerceArgument)
434+
.collect(Collectors.toList());
435+
return Optional.of(generateBinding(call, matchFunction.get(), coercedArgs, outputType));
461436
}
462437

463438
protected String getName() {
@@ -479,56 +454,30 @@ public interface GenericCall {
479454
* Coerced types according to an expected output type. Coercion is only done for type mismatches,
480455
* not for nullability or parameter mismatches.
481456
*/
482-
private static List<Expression> coerceArguments(List<Expression> arguments, Type type) {
483-
return arguments.stream().map(a -> coerceArgument(a, type)).collect(Collectors.toList());
457+
private static List<Expression> coerceArguments(List<Expression> arguments, Type targetType) {
458+
return arguments.stream().map(a -> coerceArgument(a, targetType)).collect(Collectors.toList());
484459
}
485460

486461
private static Expression coerceArgument(Expression argument, Type type) {
487-
var typeMatches = isMatch(type, argument.getType());
488-
if (!typeMatches) {
489-
return ExpressionCreator.cast(type, argument, Expression.FailureBehavior.THROW_EXCEPTION);
462+
if (isMatch(type, argument.getType())) {
463+
return argument;
490464
}
491-
return argument;
465+
466+
return ExpressionCreator.cast(type, argument, Expression.FailureBehavior.THROW_EXCEPTION);
492467
}
493468

494469
protected abstract T generateBinding(
495-
C call, F function, List<FunctionArg> arguments, Type outputType);
470+
C call, F function, List<? extends FunctionArg> arguments, Type outputType);
496471

497-
public interface SingularArgumentMatcher<F> {
472+
@FunctionalInterface
473+
private interface SingularArgumentMatcher<F> {
498474
Optional<F> tryMatch(Type type, Type outputType);
499475
}
500476

501-
public interface SignatureMatcher<F> {
502-
Optional<F> tryMatch(List<Type> types, Type outputType);
503-
}
504-
505-
private static SignatureMatcher chainedSignature(SignatureMatcher... matchers) {
506-
return switch (matchers.length) {
507-
case 0 -> (types, outputType) -> Optional.empty();
508-
case 1 -> matchers[0];
509-
default -> (types, outputType) -> {
510-
for (SignatureMatcher m : matchers) {
511-
var t = m.tryMatch(types, outputType);
512-
if (t.isPresent()) {
513-
return t;
514-
}
515-
}
516-
return Optional.empty();
517-
};
518-
};
519-
}
520-
521-
private static boolean isMatch(Type inputType, ParameterizedType type) {
522-
if (type.isWildcard()) {
523-
return true;
524-
}
525-
return inputType.accept(new IgnoreNullableAndParameters(type));
526-
}
527-
528-
private static boolean isMatch(ParameterizedType inputType, ParameterizedType type) {
529-
if (type.isWildcard()) {
477+
private static boolean isMatch(ParameterizedType actualType, ParameterizedType targetType) {
478+
if (targetType.isWildcard()) {
530479
return true;
531480
}
532-
return inputType.accept(new IgnoreNullableAndParameters(type));
481+
return actualType.accept(new IgnoreNullableAndParameters(targetType));
533482
}
534483
}

isthmus/src/main/java/io/substrait/isthmus/expression/ScalarFunctionConverter.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ public Optional<Expression> convert(
6060
protected Expression generateBinding(
6161
WrappedScalarCall call,
6262
SimpleExtension.ScalarFunctionVariant function,
63-
List<FunctionArg> arguments,
63+
List<? extends FunctionArg> arguments,
6464
Type outputType) {
6565
return Expression.ScalarFunctionInvocation.builder()
6666
.outputType(outputType)

isthmus/src/main/java/io/substrait/isthmus/expression/WindowFunctionConverter.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ public WindowFunctionConverter(
5252
protected Expression.WindowFunctionInvocation generateBinding(
5353
WrappedWindowCall call,
5454
SimpleExtension.WindowFunctionVariant function,
55-
List<FunctionArg> arguments,
55+
List<? extends FunctionArg> arguments,
5656
Type outputType) {
5757
RexOver over = call.over;
5858
RexWindow window = over.getWindow();

isthmus/src/main/java/io/substrait/isthmus/expression/WindowRelFunctionConverter.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ public WindowRelFunctionConverter(
5151
protected ConsistentPartitionWindow.WindowRelFunctionInvocation generateBinding(
5252
WrappedWindowRelCall call,
5353
SimpleExtension.WindowFunctionVariant function,
54-
List<FunctionArg> arguments,
54+
List<? extends FunctionArg> arguments,
5555
Type outputType) {
5656
Window.RexWinAggCall over = call.getWinAggCall();
5757

isthmus/src/test/java/io/substrait/isthmus/ExpressionConvertabilityTest.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,21 @@ public void inPredicate() throws IOException, SqlParseException {
6565
+ "(SELECT L_SUPPKEY from LINEITEM where L_SUPPKEY < L_ORDERKEY)");
6666
}
6767

68+
@Test
69+
public void concatStringLiteralAndVarchar() throws Exception {
70+
assertProtoPlanRoundrip("select 'part_'||P_NAME from PART");
71+
}
72+
73+
@Test
74+
public void concatCharAndVarchar() throws Exception {
75+
assertProtoPlanRoundrip("select P_BRAND||P_NAME from PART");
76+
}
77+
78+
@Test
79+
public void concatStringLiteralAndChar() throws Exception {
80+
assertProtoPlanRoundrip("select 'brand_'||P_BRAND from PART");
81+
}
82+
6883
@Test
6984
public void singleOrList() {
7085
Expression singleOrList = b.singleOrList(b.fieldReference(commonTable, 0), b.i32(5), b.i32(10));
Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
package io.substrait.isthmus;
22

33
import com.google.protobuf.util.JsonFormat;
4+
import java.util.Set;
5+
import java.util.stream.IntStream;
46
import org.apache.calcite.adapter.tpcds.TpcdsSchema;
57
import org.junit.jupiter.params.ParameterizedTest;
6-
import org.junit.jupiter.params.provider.ValueSource;
8+
import org.junit.jupiter.params.provider.MethodSource;
79

810
/**
911
*
@@ -27,33 +29,24 @@
2729
*/
2830
public class TpcdsQueryNoValidation extends PlanTestBase {
2931

32+
static final Set<Integer> EXCLUDED =
33+
Set.of(2, 9, 12, 20, 27, 36, 47, 51, 53, 57, 63, 70, 86, 89, 98);
34+
35+
static IntStream testCases() {
36+
return IntStream.rangeClosed(1, 99).filter(n -> !EXCLUDED.contains(n));
37+
}
38+
3039
/**
3140
* This test only validates that generating substrait plans for TPC-DS queries does not fail. As
3241
* of now this test does not validate correctness of the generated plan
3342
*/
34-
private void testQuery(int i) throws Exception {
43+
@ParameterizedTest
44+
@MethodSource("testCases")
45+
void testQuery(int i) throws Exception {
3546
SqlToSubstrait s = new SqlToSubstrait();
3647
TpcdsSchema schema = new TpcdsSchema(1.0);
3748
String sql = asString(String.format("tpcds/queries/%02d.sql", i));
3849
var plan = s.execute(sql, "tpcds", schema);
3950
System.out.println(JsonFormat.printer().print(plan));
4051
}
41-
42-
@ParameterizedTest
43-
@ValueSource(
44-
ints = {
45-
1, 3, 4, 6, 7, 8, 10, 11, 13, 14, 15, 16, 17, 18, 19, 21, 22, 23, 24, 25, 26, 28, 29, 30,
46-
31, 32, 33, 34, 35, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 48, 49, 50, 52, 54, 55, 56, 58,
47-
59, 60, 61, 62, 64, 65, 67, 68, 69, 71, 72, 73, 74, 75, 76, 77, 78, 79, 81, 82, 83, 85, 87,
48-
88, 90, 92, 93, 94, 95, 96, 97, 99
49-
})
50-
public void tpcdsSuccess(int query) throws Exception {
51-
testQuery(query);
52-
}
53-
54-
@ParameterizedTest
55-
@ValueSource(ints = {2, 5, 9, 12, 20, 27, 36, 47, 51, 53, 57, 63, 66, 70, 80, 84, 86, 89, 91, 98})
56-
public void tpcdsFailure(int query) throws Exception {
57-
// testQuery(query);
58-
}
5952
}

0 commit comments

Comments
 (0)