Skip to content

Commit 3145f65

Browse files
committed
[sidecar] Fix array_sort lambda function failure when sidecar is enabled and add e2e lambda function tests
1 parent 52fae25 commit 3145f65

File tree

5 files changed

+104
-9
lines changed

5 files changed

+104
-9
lines changed

presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/SemanticErrorCode.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ public enum SemanticErrorCode
5656
FUNCTION_NOT_FOUND,
5757
INVALID_FUNCTION_NAME,
5858
DUPLICATE_PARAMETER_NAME,
59+
EXCEPTIONS_WHEN_RESOLVING_FUNCTIONS,
5960

6061
ORDER_BY_MUST_BE_IN_SELECT,
6162
ORDER_BY_MUST_BE_IN_AGGREGATE,

presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/SemanticException.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ public class SemanticException
2626
{
2727
private final SemanticErrorCode code;
2828
private final Optional<NodeLocation> location;
29+
private final String formattedMessage;
2930

3031
public SemanticException(SemanticErrorCode code, String format, Object... args)
3132
{
@@ -44,10 +45,18 @@ public SemanticException(SemanticErrorCode code, Optional<NodeLocation> location
4445

4546
public SemanticException(SemanticErrorCode code, Throwable cause, Optional<NodeLocation> location, String format, Object... args)
4647
{
47-
super(formatMessage(format, location, args), cause);
48+
super(cause);
4849

4950
this.code = requireNonNull(code, "code is null");
5051
this.location = requireNonNull(location, "location is null");
52+
requireNonNull(format, "format is null");
53+
this.formattedMessage = formatMessage(format, location, args);
54+
}
55+
56+
@Override
57+
public String getMessage()
58+
{
59+
return formattedMessage;
5160
}
5261

5362
// TODO: Should be replaced with analyzer agnostic location
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package com.facebook.presto.sql.analyzer;
15+
16+
import java.util.List;
17+
18+
import static com.facebook.presto.sql.analyzer.SemanticErrorCode.EXCEPTIONS_WHEN_RESOLVING_FUNCTIONS;
19+
import static java.lang.String.format;
20+
21+
public class SignatureMatchingException
22+
extends SemanticException
23+
{
24+
public SignatureMatchingException(
25+
String prefix,
26+
List<SemanticException> failedExceptions)
27+
{
28+
super(EXCEPTIONS_WHEN_RESOLVING_FUNCTIONS, formatMessage(prefix, failedExceptions));
29+
}
30+
31+
private static String formatMessage(String formatString, List<SemanticException> failedExceptions)
32+
{
33+
StringBuilder sb = new StringBuilder(formatString).append("\n");
34+
for (int i = 0; i < failedExceptions.size(); i++) {
35+
sb.append(format(" Exception %d: %s%n", i + 1, failedExceptions.get(i).getMessage()));
36+
}
37+
return sb.toString();
38+
}
39+
}

presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionSignatureMatcher.java

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
import com.facebook.presto.spi.function.FunctionKind;
2020
import com.facebook.presto.spi.function.Signature;
2121
import com.facebook.presto.spi.function.SqlFunction;
22+
import com.facebook.presto.sql.analyzer.SemanticException;
23+
import com.facebook.presto.sql.analyzer.SignatureMatchingException;
2224
import com.facebook.presto.sql.analyzer.TypeSignatureProvider;
2325
import com.google.common.base.Joiner;
2426
import com.google.common.collect.ImmutableList;
@@ -124,15 +126,28 @@ private Optional<Signature> matchFunction(Collection<? extends SqlFunction> cand
124126
private List<ApplicableFunction> identifyApplicableFunctions(Collection<? extends SqlFunction> candidates, List<TypeSignatureProvider> actualParameters, boolean allowCoercion)
125127
{
126128
ImmutableList.Builder<ApplicableFunction> applicableFunctions = ImmutableList.builder();
129+
ImmutableList.Builder<SemanticException> semanticExceptions = ImmutableList.builder();
127130
for (SqlFunction function : candidates) {
128131
Signature declaredSignature = function.getSignature();
129-
Optional<Signature> boundSignature = new SignatureBinder(functionAndTypeManager, declaredSignature, allowCoercion)
130-
.bind(actualParameters);
131-
if (boundSignature.isPresent()) {
132-
applicableFunctions.add(new ApplicableFunction(declaredSignature, boundSignature.get(), function.isCalledOnNullInput()));
132+
try {
133+
Optional<Signature> boundSignature = new SignatureBinder(functionAndTypeManager, declaredSignature, allowCoercion)
134+
.bind(actualParameters);
135+
boundSignature.ifPresent(signature -> applicableFunctions.add(new ApplicableFunction(declaredSignature, signature, function.isCalledOnNullInput())));
133136
}
137+
catch (SemanticException e) {
138+
semanticExceptions.add(e);
139+
}
140+
}
141+
142+
List<ApplicableFunction> applicableFunctionsList = applicableFunctions.build();
143+
List<SemanticException> semanticExceptionList = semanticExceptions.build();
144+
if (applicableFunctionsList.isEmpty() && !semanticExceptionList.isEmpty()) {
145+
decideAndThrow(semanticExceptionList,
146+
candidates.stream().findFirst()
147+
.map(function -> function.getSignature().getName().getObjectName())
148+
.orElse(""));
134149
}
135-
return applicableFunctions.build();
150+
return applicableFunctionsList;
136151
}
137152

138153
private List<ApplicableFunction> selectMostSpecificFunctions(List<ApplicableFunction> applicableFunctions, List<TypeSignatureProvider> parameters)
@@ -287,6 +302,22 @@ private static boolean returnsNullOnGivenInputTypes(ApplicableFunction applicabl
287302
return true;
288303
}
289304

305+
/**
306+
* Decides which exception to throw based on the number of failed attempts.
307+
* If there's only one SemanticException, it throws that SemanticException directly.
308+
* If there are multiple SemanticExceptions, it throws the SignatureMatchingException.
309+
*/
310+
private static void decideAndThrow(List<SemanticException> failedExceptions, String functionName)
311+
throws SemanticException
312+
{
313+
if (failedExceptions.size() == 1) {
314+
throw failedExceptions.get(0);
315+
}
316+
else {
317+
throw new SignatureMatchingException(format("Failed to find matching function signature for %s, matching failures: ", functionName), failedExceptions);
318+
}
319+
}
320+
290321
static String constructFunctionNotFoundErrorMessage(QualifiedObjectName functionName, List<TypeSignatureProvider> parameterTypes, Collection<? extends SqlFunction> candidates)
291322
{
292323
String name = toConciseFunctionName(functionName);

presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/TestNativeSidecarPlugin.java

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -188,10 +188,25 @@ public void testWindowFunctions()
188188
}
189189

190190
@Test
191-
public void testArraySort()
191+
public void testLambdaFunctions()
192192
{
193-
assertQueryFails("SELECT array_sort(quantities, (x, y) -> if (x < y, 1, if (x > y, -1, 0))) FROM orders_ex",
194-
"line 1:31: Expected a lambda that takes 1 argument\\(s\\) but got 2");
193+
// These function signatures are only supported in the native execution engine
194+
assertQuerySucceeds("select array_sort(array[row('apples', 23), row('bananas', 12), row('grapes', 44)], x -> x[2])");
195+
assertQuerySucceeds("SELECT array_sort(quantities, x -> abs(x)) FROM orders_ex");
196+
assertQuerySucceeds("SELECT array_sort(quantities, (x, y) -> if (x < y, cast(1 as bigint), if (x > y, cast(-1 as bigint), cast(0 as bigint)))) FROM orders_ex");
197+
198+
assertQuery("SELECT array_sort(map_keys(map_union(quantity_by_linenumber))) FROM orders_ex");
199+
assertQuery("SELECT filter(quantities, q -> q > 10) FROM orders_ex");
200+
assertQuery("SELECT all_match(shuffle(quantities), x -> (x > 500.0)) FROM orders_ex");
201+
assertQuery("SELECT any_match(quantities, x -> TRY(((10 / x) > 2))) FROM orders_ex");
202+
assertQuery("SELECT TRY(none_match(quantities, x -> ((10 / x) > 2))) FROM orders_ex");
203+
assertQuery("SELECT reduce(array[nationkey, regionkey], 103, (s, x) -> s + x, s -> s) FROM nation");
204+
assertQuery("SELECT transform(array[1, 2, 3], x -> x * regionkey + nationkey) FROM nation");
205+
assertQueryFails(
206+
"SELECT array_sort(quantities, (x, y, z) -> if (x < y + z, cast(1 as bigint), if (x > y + z, cast(-1 as bigint), cast(0 as bigint)))) FROM orders_ex",
207+
Pattern.quote("Failed to find matching function signature for array_sort, matching failures: \n" +
208+
" Exception 1: line 1:31: Expected a lambda that takes 1 argument(s) but got 3\n" +
209+
" Exception 2: line 1:31: Expected a lambda that takes 2 argument(s) but got 3\n"));
195210
}
196211

197212
@Test

0 commit comments

Comments
 (0)