2525import java .util .List ;
2626import java .util .Locale ;
2727import java .util .Map ;
28+ import java .util .Objects ;
2829import java .util .Optional ;
2930import java .util .Set ;
3031import java .util .function .Function ;
@@ -148,7 +149,6 @@ protected class FunctionFinder {
148149 private final SqlOperator operator ;
149150 private final List <F > functions ;
150151 private final Map <String , F > directMap ;
151- private final SignatureMatcher <F > matcher ;
152152 private final Optional <SingularArgumentMatcher <F >> singularInputType ;
153153 private final Util .IntRange argRange ;
154154
@@ -160,7 +160,6 @@ public FunctionFinder(String name, SqlOperator operator, List<F> functions) {
160160 Util .IntRange .of (
161161 functions .stream ().mapToInt (t -> t .getRange ().getStartInclusive ()).min ().getAsInt (),
162162 functions .stream ().mapToInt (t -> t .getRange ().getEndExclusive ()).max ().getAsInt ());
163- this .matcher = getSignatureMatcher (operator , functions );
164163 this .singularInputType = getSingularInputType (functions );
165164 var directMap = ImmutableMap .<String , F >builder ();
166165 for (var func : functions ) {
@@ -177,21 +176,18 @@ public boolean allowedArgCount(int count) {
177176 return argRange .within (count );
178177 }
179178
180- private static <F extends SimpleExtension .Function > SignatureMatcher <F > getSignatureMatcher (
181- SqlOperator operator , List <F > functions ) {
182- return (inputTypes , outputType ) -> {
183- for (F function : functions ) {
184- List <SimpleExtension .Argument > args = function .requiredArguments ();
185- // Make sure that arguments & return are within bounds and match the types
186- if (function .returnType () instanceof ParameterizedType
187- && isMatch (outputType , (ParameterizedType ) function .returnType ())
188- && inputTypesSatisfyDefinedArguments (inputTypes , args )) {
189- return Optional .of (function );
190- }
179+ private Optional <F > signatureMatch (List <Type > inputTypes , Type outputType ) {
180+ for (F function : functions ) {
181+ List <SimpleExtension .Argument > args = function .requiredArguments ();
182+ // Make sure that arguments & return are within bounds and match the types
183+ if (function .returnType () instanceof ParameterizedType
184+ && isMatch (outputType , (ParameterizedType ) function .returnType ())
185+ && inputTypesMatchDefinedArguments (inputTypes , args )) {
186+ return Optional .of (function );
191187 }
188+ }
192189
193- return Optional .empty ();
194- };
190+ return Optional .empty ();
195191 }
196192
197193 /**
@@ -207,7 +203,7 @@ && inputTypesSatisfyDefinedArguments(inputTypes, args)) {
207203 * @param args expected arguments as defined in a {@link SimpleExtension.Function}
208204 * @return true if the {@code inputTypes} satisfy the {@code args}, false otherwise
209205 */
210- private static boolean inputTypesSatisfyDefinedArguments (
206+ private static boolean inputTypesMatchDefinedArguments (
211207 List <Type > inputTypes , List <SimpleExtension .Argument > args ) {
212208
213209 Map <String , Set <Type >> wildcardToType = new HashMap <>();
@@ -317,7 +313,7 @@ private Stream<String> matchKeys(List<RexNode> rexOperands, List<String> opTypes
317313
318314 assert (rexOperands .size () == opTypes .size ());
319315
320- if (rexOperands .size () == 0 ) {
316+ if (rexOperands .isEmpty () ) {
321317 return Stream .of ("" );
322318 } else {
323319 List <List <String >> argTypeLists =
@@ -346,23 +342,20 @@ public Optional<T> attemptMatch(C call, Function<RexNode, Expression> topLevelCo
346342 * Once a FunctionVariant is resolved we can map the String Literal
347343 * to a EnumArg.
348344 */
349- var operands =
350- call .getOperands ().map (topLevelConverter ).collect (java .util .stream .Collectors .toList ());
351- var opTypes =
352- operands .stream ().map (Expression ::getType ).collect (java .util .stream .Collectors .toList ());
345+ var operands = call .getOperands ().map (topLevelConverter ).collect (Collectors .toList ());
346+ var opTypes = operands .stream ().map (Expression ::getType ).collect (Collectors .toList ());
353347
354348 var outputType = typeConverter .toSubstrait (call .getType ());
355349
356350 // try to do a direct match
357351 var typeStrings =
358352 opTypes .stream ().map (t -> t .accept (ToTypeString .INSTANCE )).collect (Collectors .toList ());
359- var possibleKeys =
360- matchKeys (call .getOperands ().collect (java .util .stream .Collectors .toList ()), typeStrings );
353+ var possibleKeys = matchKeys (call .getOperands ().collect (Collectors .toList ()), typeStrings );
361354
362355 var directMatchKey =
363356 possibleKeys
364357 .map (argList -> name + ":" + argList )
365- .filter (k -> directMap . containsKey ( k ) )
358+ .filter (directMap :: containsKey )
366359 .findFirst ();
367360
368361 if (directMatchKey .isPresent ()) {
@@ -375,13 +368,13 @@ public Optional<T> attemptMatch(C call, Function<RexNode, Expression> topLevelCo
375368 operands .stream (),
376369 (r , o ) -> {
377370 if (EnumConverter .isEnumValue (r )) {
378- return EnumConverter .fromRex (variant , (RexLiteral ) r ).orElseGet (() -> null );
371+ return EnumConverter .fromRex (variant , (RexLiteral ) r ).orElse ( null );
379372 } else {
380373 return o ;
381374 }
382375 })
383- .collect (java . util . stream . Collectors .toList ());
384- var allArgsMapped = funcArgs .stream ().filter (e -> e == null ).findFirst ().isEmpty ();
376+ .collect (Collectors .toList ());
377+ var allArgsMapped = funcArgs .stream ().filter (Objects :: isNull ).findFirst ().isEmpty ();
385378 if (allArgsMapped ) {
386379 return Optional .of (generateBinding (call , variant , funcArgs , outputType ));
387380 } else {
@@ -411,53 +404,35 @@ private Optional<T> matchByLeastRestrictive(
411404 return Optional .empty ();
412405 }
413406 Type type = typeConverter .toSubstrait (leastRestrictive );
414- var out = singularInputType .get ().tryMatch (type , outputType );
415-
416- if (out .isPresent ()) {
417- var declaration = out .get ();
418- var coercedArgs = coerceArguments (operands , type );
419- declaration .validateOutputType (coercedArgs , outputType );
420- return Optional .of (
421- generateBinding (
422- call ,
423- out .get (),
424- coercedArgs .stream ().map (FunctionArg .class ::cast ).collect (Collectors .toList ()),
425- outputType ));
426- }
427- return Optional .empty ();
407+ var out = singularInputType .orElseThrow ().tryMatch (type , outputType );
408+
409+ return out .map (
410+ declaration -> {
411+ var coercedArgs = coerceArguments (operands , type );
412+ declaration .validateOutputType (coercedArgs , outputType );
413+ return generateBinding (call , out .get (), coercedArgs , outputType );
414+ });
428415 }
429416
430- private Optional <T > matchCoerced (C call , Type outputType , List <Expression > operands ) {
431-
417+ private Optional <T > matchCoerced (C call , Type outputType , List <Expression > expressions ) {
432418 // Convert the operands to the proper Substrait type
433- List <Type > allTypes =
419+ List <Type > operandTypes =
434420 call .getOperands ()
435421 .map (RexNode ::getType )
436422 .map (typeConverter ::toSubstrait )
437423 .collect (Collectors .toList ());
438424
439- // See if all the input types match the function
440- Optional <F > matchFunction = this .matcher .tryMatch (allTypes , outputType );
441- if (matchFunction .isPresent ()) {
442- List <Expression > coerced =
443- Streams .zip (
444- operands .stream (),
445- call .getOperands (),
446- (a , b ) -> {
447- Type type = typeConverter .toSubstrait (b .getType ());
448- return coerceArgument (a , type );
449- })
450- .collect (Collectors .toList ());
451-
452- return Optional .of (
453- generateBinding (
454- call ,
455- matchFunction .get (),
456- coerced .stream ().map (FunctionArg .class ::cast ).collect (Collectors .toList ()),
457- outputType ));
425+ // See if all the input types can be made to match the function
426+ Optional <F > matchFunction = signatureMatch (operandTypes , outputType );
427+ if (matchFunction .isEmpty ()) {
428+ return Optional .empty ();
458429 }
459430
460- return Optional .empty ();
431+ var coercedArgs =
432+ Streams .zip (
433+ expressions .stream (), operandTypes .stream (), FunctionConverter ::coerceArgument )
434+ .collect (Collectors .toList ());
435+ return Optional .of (generateBinding (call , matchFunction .get (), coercedArgs , outputType ));
461436 }
462437
463438 protected String getName () {
@@ -479,56 +454,30 @@ public interface GenericCall {
479454 * Coerced types according to an expected output type. Coercion is only done for type mismatches,
480455 * not for nullability or parameter mismatches.
481456 */
482- private static List <Expression > coerceArguments (List <Expression > arguments , Type type ) {
483- return arguments .stream ().map (a -> coerceArgument (a , type )).collect (Collectors .toList ());
457+ private static List <Expression > coerceArguments (List <Expression > arguments , Type targetType ) {
458+ return arguments .stream ().map (a -> coerceArgument (a , targetType )).collect (Collectors .toList ());
484459 }
485460
486461 private static Expression coerceArgument (Expression argument , Type type ) {
487- var typeMatches = isMatch (type , argument .getType ());
488- if (!typeMatches ) {
489- return ExpressionCreator .cast (type , argument , Expression .FailureBehavior .THROW_EXCEPTION );
462+ if (isMatch (type , argument .getType ())) {
463+ return argument ;
490464 }
491- return argument ;
465+
466+ return ExpressionCreator .cast (type , argument , Expression .FailureBehavior .THROW_EXCEPTION );
492467 }
493468
494469 protected abstract T generateBinding (
495- C call , F function , List <FunctionArg > arguments , Type outputType );
470+ C call , F function , List <? extends FunctionArg > arguments , Type outputType );
496471
497- public interface SingularArgumentMatcher <F > {
472+ @ FunctionalInterface
473+ private interface SingularArgumentMatcher <F > {
498474 Optional <F > tryMatch (Type type , Type outputType );
499475 }
500476
501- public interface SignatureMatcher <F > {
502- Optional <F > tryMatch (List <Type > types , Type outputType );
503- }
504-
505- private static SignatureMatcher chainedSignature (SignatureMatcher ... matchers ) {
506- return switch (matchers .length ) {
507- case 0 -> (types , outputType ) -> Optional .empty ();
508- case 1 -> matchers [0 ];
509- default -> (types , outputType ) -> {
510- for (SignatureMatcher m : matchers ) {
511- var t = m .tryMatch (types , outputType );
512- if (t .isPresent ()) {
513- return t ;
514- }
515- }
516- return Optional .empty ();
517- };
518- };
519- }
520-
521- private static boolean isMatch (Type inputType , ParameterizedType type ) {
522- if (type .isWildcard ()) {
523- return true ;
524- }
525- return inputType .accept (new IgnoreNullableAndParameters (type ));
526- }
527-
528- private static boolean isMatch (ParameterizedType inputType , ParameterizedType type ) {
529- if (type .isWildcard ()) {
477+ private static boolean isMatch (ParameterizedType actualType , ParameterizedType targetType ) {
478+ if (targetType .isWildcard ()) {
530479 return true ;
531480 }
532- return inputType .accept (new IgnoreNullableAndParameters (type ));
481+ return actualType .accept (new IgnoreNullableAndParameters (targetType ));
533482 }
534483}
0 commit comments