2525import java .util .List ;
2626import java .util .Locale ;
2727import java .util .Map ;
28+ import java .util .Objects ;
2829import java .util .Optional ;
2930import java .util .Set ;
3031import java .util .function .Function ;
@@ -149,7 +150,6 @@ protected class FunctionFinder {
149150 private final SqlOperator operator ;
150151 private final List <F > functions ;
151152 private final Map <String , F > directMap ;
152- private final SignatureMatcher <F > matcher ;
153153 private final Optional <SingularArgumentMatcher <F >> singularInputType ;
154154 private final Util .IntRange argRange ;
155155
@@ -161,7 +161,6 @@ public FunctionFinder(String name, SqlOperator operator, List<F> functions) {
161161 Util .IntRange .of (
162162 functions .stream ().mapToInt (t -> t .getRange ().getStartInclusive ()).min ().getAsInt (),
163163 functions .stream ().mapToInt (t -> t .getRange ().getEndExclusive ()).max ().getAsInt ());
164- this .matcher = getSignatureMatcher (operator , functions );
165164 this .singularInputType = getSingularInputType (functions );
166165 var directMap = ImmutableMap .<String , F >builder ();
167166 for (var func : functions ) {
@@ -178,21 +177,18 @@ public boolean allowedArgCount(int count) {
178177 return argRange .within (count );
179178 }
180179
181- private static <F extends SimpleExtension .Function > SignatureMatcher <F > getSignatureMatcher (
182- SqlOperator operator , List <F > functions ) {
183- return (inputTypes , outputType ) -> {
184- for (F function : functions ) {
185- List <SimpleExtension .Argument > args = function .requiredArguments ();
186- // Make sure that arguments & return are within bounds and match the types
187- if (function .returnType () instanceof ParameterizedType
188- && isMatch (outputType , (ParameterizedType ) function .returnType ())
189- && inputTypesSatisfyDefinedArguments (inputTypes , args )) {
190- return Optional .of (function );
191- }
180+ private Optional <F > signatureMatch (List <Type > inputTypes , Type outputType ) {
181+ for (F function : functions ) {
182+ List <SimpleExtension .Argument > args = function .requiredArguments ();
183+ // Make sure that arguments & return are within bounds and match the types
184+ if (function .returnType () instanceof ParameterizedType
185+ && isMatch (outputType , (ParameterizedType ) function .returnType ())
186+ && inputTypesMatchDefinedArguments (inputTypes , args )) {
187+ return Optional .of (function );
192188 }
189+ }
193190
194- return Optional .empty ();
195- };
191+ return Optional .empty ();
196192 }
197193
198194 /**
@@ -208,7 +204,7 @@ && inputTypesSatisfyDefinedArguments(inputTypes, args)) {
208204 * @param args expected arguments as defined in a {@link SimpleExtension.Function}
209205 * @return true if the {@code inputTypes} satisfy the {@code args}, false otherwise
210206 */
211- private static boolean inputTypesSatisfyDefinedArguments (
207+ private static boolean inputTypesMatchDefinedArguments (
212208 List <Type > inputTypes , List <SimpleExtension .Argument > args ) {
213209
214210 Map <String , Set <Type >> wildcardToType = new HashMap <>();
@@ -318,7 +314,7 @@ private Stream<String> matchKeys(List<RexNode> rexOperands, List<String> opTypes
318314
319315 assert (rexOperands .size () == opTypes .size ());
320316
321- if (rexOperands .size () == 0 ) {
317+ if (rexOperands .isEmpty () ) {
322318 return Stream .of ("" );
323319 } else {
324320 List <List <String >> argTypeLists =
@@ -357,13 +353,12 @@ public Optional<T> attemptMatch(C call, Function<RexNode, Expression> topLevelCo
357353 // try to do a direct match
358354 List <String > typeStrings =
359355 opTypes .stream ().map (t -> t .accept (ToTypeString .INSTANCE )).collect (Collectors .toList ());
360- Stream <String > possibleKeys =
361- matchKeys (call .getOperands ().collect (Collectors .toList ()), typeStrings );
356+ Stream <String > possibleKeys = matchKeys (operandsList , typeStrings );
362357
363358 Optional <String > directMatchKey =
364359 possibleKeys
365360 .map (argList -> name + ":" + argList )
366- .filter (k -> directMap . containsKey ( k ) )
361+ .filter (directMap :: containsKey )
367362 .findFirst ();
368363
369364 if (directMatchKey .isPresent ()) {
@@ -376,14 +371,13 @@ public Optional<T> attemptMatch(C call, Function<RexNode, Expression> topLevelCo
376371 RexNode r = operandsList .get (i );
377372 Expression o = operands .get (i );
378373 if (EnumConverter .isEnumValue (r )) {
379- return EnumConverter .fromRex (variant , (RexLiteral ) r , i )
380- .orElseGet (() -> null );
374+ return EnumConverter .fromRex (variant , (RexLiteral ) r , i ).orElse (null );
381375 } else {
382376 return o ;
383377 }
384378 })
385379 .collect (Collectors .toList ());
386- boolean allArgsMapped = funcArgs .stream ().filter (e -> e == null ).findFirst ().isEmpty ();
380+ boolean allArgsMapped = funcArgs .stream ().filter (Objects :: isNull ).findFirst ().isEmpty ();
387381 if (allArgsMapped ) {
388382 return Optional .of (generateBinding (call , variant , funcArgs , outputType ));
389383 } else {
@@ -413,53 +407,35 @@ private Optional<T> matchByLeastRestrictive(
413407 return Optional .empty ();
414408 }
415409 Type type = typeConverter .toSubstrait (leastRestrictive );
416- var out = singularInputType .get ().tryMatch (type , outputType );
417-
418- if (out .isPresent ()) {
419- var declaration = out .get ();
420- var coercedArgs = coerceArguments (operands , type );
421- declaration .validateOutputType (coercedArgs , outputType );
422- return Optional .of (
423- generateBinding (
424- call ,
425- out .get (),
426- coercedArgs .stream ().map (FunctionArg .class ::cast ).collect (Collectors .toList ()),
427- outputType ));
428- }
429- return Optional .empty ();
410+ var out = singularInputType .orElseThrow ().tryMatch (type , outputType );
411+
412+ return out .map (
413+ declaration -> {
414+ var coercedArgs = coerceArguments (operands , type );
415+ declaration .validateOutputType (coercedArgs , outputType );
416+ return generateBinding (call , out .get (), coercedArgs , outputType );
417+ });
430418 }
431419
432- private Optional <T > matchCoerced (C call , Type outputType , List <Expression > operands ) {
433-
420+ private Optional <T > matchCoerced (C call , Type outputType , List <Expression > expressions ) {
434421 // Convert the operands to the proper Substrait type
435- List <Type > allTypes =
422+ List <Type > operandTypes =
436423 call .getOperands ()
437424 .map (RexNode ::getType )
438425 .map (typeConverter ::toSubstrait )
439426 .collect (Collectors .toList ());
440427
441- // See if all the input types match the function
442- Optional <F > matchFunction = this .matcher .tryMatch (allTypes , outputType );
443- if (matchFunction .isPresent ()) {
444- List <Expression > coerced =
445- Streams .zip (
446- operands .stream (),
447- call .getOperands (),
448- (a , b ) -> {
449- Type type = typeConverter .toSubstrait (b .getType ());
450- return coerceArgument (a , type );
451- })
452- .collect (Collectors .toList ());
453-
454- return Optional .of (
455- generateBinding (
456- call ,
457- matchFunction .get (),
458- coerced .stream ().map (FunctionArg .class ::cast ).collect (Collectors .toList ()),
459- outputType ));
428+ // See if all the input types can be made to match the function
429+ Optional <F > matchFunction = signatureMatch (operandTypes , outputType );
430+ if (matchFunction .isEmpty ()) {
431+ return Optional .empty ();
460432 }
461433
462- return Optional .empty ();
434+ var coercedArgs =
435+ Streams .zip (
436+ expressions .stream (), operandTypes .stream (), FunctionConverter ::coerceArgument )
437+ .collect (Collectors .toList ());
438+ return Optional .of (generateBinding (call , matchFunction .get (), coercedArgs , outputType ));
463439 }
464440
465441 protected String getName () {
@@ -481,56 +457,30 @@ public interface GenericCall {
481457 * Coerced types according to an expected output type. Coercion is only done for type mismatches,
482458 * not for nullability or parameter mismatches.
483459 */
484- private static List <Expression > coerceArguments (List <Expression > arguments , Type type ) {
485- return arguments .stream ().map (a -> coerceArgument (a , type )).collect (Collectors .toList ());
460+ private static List <Expression > coerceArguments (List <Expression > arguments , Type targetType ) {
461+ return arguments .stream ().map (a -> coerceArgument (a , targetType )).collect (Collectors .toList ());
486462 }
487463
488464 private static Expression coerceArgument (Expression argument , Type type ) {
489- var typeMatches = isMatch (type , argument .getType ());
490- if (!typeMatches ) {
491- return ExpressionCreator .cast (type , argument , Expression .FailureBehavior .THROW_EXCEPTION );
465+ if (isMatch (type , argument .getType ())) {
466+ return argument ;
492467 }
493- return argument ;
468+
469+ return ExpressionCreator .cast (type , argument , Expression .FailureBehavior .THROW_EXCEPTION );
494470 }
495471
496472 protected abstract T generateBinding (
497- C call , F function , List <FunctionArg > arguments , Type outputType );
473+ C call , F function , List <? extends FunctionArg > arguments , Type outputType );
498474
499- public interface SingularArgumentMatcher <F > {
475+ @ FunctionalInterface
476+ private interface SingularArgumentMatcher <F > {
500477 Optional <F > tryMatch (Type type , Type outputType );
501478 }
502479
503- public interface SignatureMatcher <F > {
504- Optional <F > tryMatch (List <Type > types , Type outputType );
505- }
506-
507- private static SignatureMatcher chainedSignature (SignatureMatcher ... matchers ) {
508- return switch (matchers .length ) {
509- case 0 -> (types , outputType ) -> Optional .empty ();
510- case 1 -> matchers [0 ];
511- default -> (types , outputType ) -> {
512- for (SignatureMatcher m : matchers ) {
513- var t = m .tryMatch (types , outputType );
514- if (t .isPresent ()) {
515- return t ;
516- }
517- }
518- return Optional .empty ();
519- };
520- };
521- }
522-
523- private static boolean isMatch (Type inputType , ParameterizedType type ) {
524- if (type .isWildcard ()) {
525- return true ;
526- }
527- return inputType .accept (new IgnoreNullableAndParameters (type ));
528- }
529-
530- private static boolean isMatch (ParameterizedType inputType , ParameterizedType type ) {
531- if (type .isWildcard ()) {
480+ private static boolean isMatch (ParameterizedType actualType , ParameterizedType targetType ) {
481+ if (targetType .isWildcard ()) {
532482 return true ;
533483 }
534- return inputType .accept (new IgnoreNullableAndParameters (type ));
484+ return actualType .accept (new IgnoreNullableAndParameters (targetType ));
535485 }
536486}
0 commit comments