Skip to content

Commit f712f7a

Browse files
committed
Add support for PostgreSQL-style shorthand casts
Standard SQL supports static method calls on types via the :: operator. While this syntax is generally incompatible with PostgreSQL shorthand cast syntax, there's a subset of that can be safely repurposed to support that functionality. The SQL specification defines static method calls as: <static method invocation> ::= <path-resolved user-defined type name> <double colon> <method name> [ <SQL argument list> ] where <path-resolved user-defined type name> translates to: <user-defined type name> ::= [ <schema name> <period> ] <qualified identifier> To support casts, we need to extend the rule to support arbitrary expression as the target of the invocation. To disambiguate a static method call from a cast, we distinguish between type-producing expressions and expressions that produces regular values. For the latter, if the method matches the name of a well-known type, we treat it as a cast. Otherwise, the expression assumed to be a static method call and fail the evaluation with a "not yet supported" error. One limitation is that casts are only supported for simple types. Parametric types are not yet supported, but it wouldn't be too hard to add. Types whose name don't match the syntax of SQL function calls are not supported either and will be harder to support. It will require introducing type-producing expressions into the language and making types a first-class expression.
1 parent 6a9cc5a commit f712f7a

File tree

10 files changed

+338
-50
lines changed

10 files changed

+338
-50
lines changed

core/trino-grammar/src/main/antlr4/io/trino/grammar/sql/SqlBase.g4

Lines changed: 47 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -582,66 +582,70 @@ valueExpression
582582
;
583583

584584
primaryExpression
585-
: literal #literals
586-
| QUESTION_MARK #parameter
587-
| POSITION '(' valueExpression IN valueExpression ')' #position
588-
| '(' expression (',' expression)+ ')' #rowConstructor
589-
| ROW '(' expression (',' expression)* ')' #rowConstructor
585+
: literal #literals
586+
| QUESTION_MARK #parameter
587+
| POSITION '(' valueExpression IN valueExpression ')' #position
588+
| '(' expression (',' expression)+ ')' #rowConstructor
589+
| ROW '(' expression (',' expression)* ')' #rowConstructor
590590
| name=LISTAGG '(' setQuantifier? expression (',' string)?
591591
(ON OVERFLOW listAggOverflowBehavior)? ')'
592592
(WITHIN GROUP '(' orderBy ')')
593-
filter? over? #listagg
593+
filter? over? #listagg
594594
| processingMode? qualifiedName '(' (label=identifier '.')? ASTERISK ')'
595-
filter? over? #functionCall
595+
filter? over? #functionCall
596596
| processingMode? qualifiedName '(' (setQuantifier? expression (',' expression)*)?
597-
orderBy? ')' filter? (nullTreatment? over)? #functionCall
598-
| identifier over #measure
599-
| identifier '->' expression #lambda
600-
| '(' (identifier (',' identifier)*)? ')' '->' expression #lambda
601-
| '(' query ')' #subqueryExpression
597+
orderBy? ')' filter? (nullTreatment? over)? #functionCall
598+
| identifier over #measure
599+
| identifier '->' expression #lambda
600+
| '(' (identifier (',' identifier)*)? ')' '->' expression #lambda
601+
| '(' query ')' #subqueryExpression
602602
// This is an extension to ANSI SQL, which considers EXISTS to be a <boolean expression>
603-
| EXISTS '(' query ')' #exists
604-
| CASE operand=expression whenClause+ (ELSE elseExpression=expression)? END #simpleCase
605-
| CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase
606-
| CAST '(' expression AS type ')' #cast
607-
| TRY_CAST '(' expression AS type ')' #cast
608-
| ARRAY '[' (expression (',' expression)*)? ']' #arrayConstructor
609-
| '[' (expression (',' expression)*)? ']' #arrayConstructor
610-
| value=primaryExpression '[' index=valueExpression ']' #subscript
611-
| identifier #columnReference
612-
| base=primaryExpression '.' fieldName=identifier #dereference
613-
| name=CURRENT_DATE #currentDate
614-
| name=CURRENT_TIME ('(' precision=INTEGER_VALUE ')')? #currentTime
615-
| name=CURRENT_TIMESTAMP ('(' precision=INTEGER_VALUE ')')? #currentTimestamp
616-
| name=LOCALTIME ('(' precision=INTEGER_VALUE ')')? #localTime
617-
| name=LOCALTIMESTAMP ('(' precision=INTEGER_VALUE ')')? #localTimestamp
618-
| name=CURRENT_USER #currentUser
619-
| name=CURRENT_CATALOG #currentCatalog
620-
| name=CURRENT_SCHEMA #currentSchema
621-
| name=CURRENT_PATH #currentPath
622-
| TRIM '(' (trimsSpecification? trimChar=valueExpression? FROM)?
623-
trimSource=valueExpression ')' #trim
624-
| TRIM '(' trimSource=valueExpression ',' trimChar=valueExpression ')' #trim
625-
| SUBSTRING '(' valueExpression FROM valueExpression (FOR valueExpression)? ')' #substring
626-
| NORMALIZE '(' valueExpression (',' normalForm)? ')' #normalize
627-
| EXTRACT '(' identifier FROM valueExpression ')' #extract
628-
| '(' expression ')' #parenthesizedExpression
629-
| GROUPING '(' (qualifiedName (',' qualifiedName)*)? ')' #groupingOperation
630-
| JSON_EXISTS '(' jsonPathInvocation (jsonExistsErrorBehavior ON ERROR)? ')' #jsonExists
603+
| EXISTS '(' query ')' #exists
604+
| CASE operand=expression whenClause+ (ELSE elseExpression=expression)? END #simpleCase
605+
| CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase
606+
| CAST '(' expression AS type ')' #cast
607+
| TRY_CAST '(' expression AS type ')' #cast
608+
// the target is a primaryExpression to support PostgreSQL-style casts
609+
// of the form <complex expression>::<type>, which are syntactically ambiguous with
610+
// static method calls
611+
| primaryExpression DOUBLE_COLON identifier ('(' (expression (',' expression)*)? ')')? #staticMethodCall
612+
| ARRAY '[' (expression (',' expression)*)? ']' #arrayConstructor
613+
| '[' (expression (',' expression)*)? ']' #arrayConstructor
614+
| value=primaryExpression '[' index=valueExpression ']' #subscript
615+
| identifier #columnReference
616+
| base=primaryExpression '.' fieldName=identifier #dereference
617+
| name=CURRENT_DATE #currentDate
618+
| name=CURRENT_TIME ('(' precision=INTEGER_VALUE ')')? #currentTime
619+
| name=CURRENT_TIMESTAMP ('(' precision=INTEGER_VALUE ')')? #currentTimestamp
620+
| name=LOCALTIME ('(' precision=INTEGER_VALUE ')')? #localTime
621+
| name=LOCALTIMESTAMP ('(' precision=INTEGER_VALUE ')')? #localTimestamp
622+
| name=CURRENT_USER #currentUser
623+
| name=CURRENT_CATALOG #currentCatalog
624+
| name=CURRENT_SCHEMA #currentSchema
625+
| name=CURRENT_PATH #currentPath
626+
| TRIM '(' (trimsSpecification? trimChar=valueExpression? FROM)?
627+
trimSource=valueExpression ')' #trim
628+
| TRIM '(' trimSource=valueExpression ',' trimChar=valueExpression ')' #trim
629+
| SUBSTRING '(' valueExpression FROM valueExpression (FOR valueExpression)? ')' #substring
630+
| NORMALIZE '(' valueExpression (',' normalForm)? ')' #normalize
631+
| EXTRACT '(' identifier FROM valueExpression ')' #extract
632+
| '(' expression ')' #parenthesizedExpression
633+
| GROUPING '(' (qualifiedName (',' qualifiedName)*)? ')' #groupingOperation
634+
| JSON_EXISTS '(' jsonPathInvocation (jsonExistsErrorBehavior ON ERROR)? ')' #jsonExists
631635
| JSON_VALUE '('
632636
jsonPathInvocation
633637
(RETURNING type)?
634638
(emptyBehavior=jsonValueBehavior ON EMPTY)?
635639
(errorBehavior=jsonValueBehavior ON ERROR)?
636-
')' #jsonValue
640+
')' #jsonValue
637641
| JSON_QUERY '('
638642
jsonPathInvocation
639643
(RETURNING type (FORMAT jsonRepresentation)?)?
640644
(jsonQueryWrapperBehavior WRAPPER)?
641645
((KEEP | OMIT) QUOTES (ON SCALAR TEXT_STRING)?)?
642646
(emptyBehavior=jsonQueryBehavior ON EMPTY)?
643647
(errorBehavior=jsonQueryBehavior ON ERROR)?
644-
')' #jsonQuery
648+
')' #jsonQuery
645649
| JSON_OBJECT '('
646650
(
647651
jsonObjectMember (',' jsonObjectMember)*
@@ -1119,6 +1123,7 @@ DISTINCT: 'DISTINCT';
11191123
DISTRIBUTED: 'DISTRIBUTED';
11201124
DO: 'DO';
11211125
DOUBLE: 'DOUBLE';
1126+
DOUBLE_COLON: '::';
11221127
DROP: 'DROP';
11231128
ELSE: 'ELSE';
11241129
EMPTY: 'EMPTY';

core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@
145145
import io.trino.sql.tree.SkipTo;
146146
import io.trino.sql.tree.SortItem;
147147
import io.trino.sql.tree.SortItem.Ordering;
148+
import io.trino.sql.tree.StaticMethodCall;
148149
import io.trino.sql.tree.StringLiteral;
149150
import io.trino.sql.tree.SubqueryExpression;
150151
import io.trino.sql.tree.SubscriptExpression;
@@ -1704,6 +1705,44 @@ private void analyzeFrameRangeOffset(Expression offsetValue, FrameBound.Type bou
17041705
frameBoundCalculations.put(NodeRef.of(offsetValue), function);
17051706
}
17061707

1708+
@Override
1709+
protected Type visitStaticMethodCall(StaticMethodCall node, Context context)
1710+
{
1711+
// PostgreSQL-style are syntactically ambiguous with static method calls. So, static method call semantics take precendence.
1712+
// A static method call is characterized by the target being an expression whose type is "type". This not yet supported
1713+
// as a first-class concept, so we fake it by analyzing the expression normally. If the analysis succeeds, we treat it as
1714+
// the target of a cast. If the analysis fails, check whether the target is an identifier matching a known type name.
1715+
try {
1716+
process(node.getTarget(), context);
1717+
}
1718+
catch (TrinoException e) {
1719+
// assume it might be a type name, so check if it's an identifier matching a known type
1720+
if (node.getTarget() instanceof Identifier target) {
1721+
try {
1722+
plannerContext.getTypeManager().fromSqlType(target.getValue());
1723+
}
1724+
catch (TypeNotFoundException typeException) {
1725+
// since the type is not found, treat the expression as normal expression that failed analysis
1726+
throw e;
1727+
}
1728+
}
1729+
throw semanticException(NOT_SUPPORTED, node, "Static method calls are not supported");
1730+
}
1731+
1732+
if (!node.getArguments().isEmpty()) {
1733+
throw semanticException(NOT_SUPPORTED, node, "Static method calls are not supported");
1734+
}
1735+
1736+
// assume it's a PostgreSQL-style cast unless result type is not a known type
1737+
try {
1738+
Type type = plannerContext.getTypeManager().fromSqlType(node.getMethod().getValue());
1739+
return setExpressionType(node, type);
1740+
}
1741+
catch (Exception e) {
1742+
throw semanticException(NOT_SUPPORTED, node, "Static method calls are not supported");
1743+
}
1744+
}
1745+
17071746
@Override
17081747
protected Type visitWindowOperation(WindowOperation node, Context context)
17091748
{

core/trino-main/src/main/java/io/trino/sql/planner/TranslationMap.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@
109109
import io.trino.sql.tree.Row;
110110
import io.trino.sql.tree.SearchedCaseExpression;
111111
import io.trino.sql.tree.SimpleCaseExpression;
112+
import io.trino.sql.tree.StaticMethodCall;
112113
import io.trino.sql.tree.StringLiteral;
113114
import io.trino.sql.tree.SubscriptExpression;
114115
import io.trino.sql.tree.Trim;
@@ -316,6 +317,7 @@ private io.trino.sql.ir.Expression translate(Expression expr, boolean isRoot)
316317
case io.trino.sql.tree.FieldReference expression -> translate(expression);
317318
case Identifier expression -> translate(expression);
318319
case FunctionCall expression -> translate(expression);
320+
case StaticMethodCall expression -> translate(expression);
319321
case DereferenceExpression expression -> translate(expression);
320322
case Array expression -> translate(expression);
321323
case CurrentCatalog expression -> translate(expression);
@@ -663,6 +665,14 @@ private io.trino.sql.ir.Expression translate(FunctionCall expression)
663665
.collect(toImmutableList()));
664666
}
665667

668+
private io.trino.sql.ir.Expression translate(StaticMethodCall expression)
669+
{
670+
// Currently, only PostgreSQL-style cast shorthand expressions are supported
671+
return new io.trino.sql.ir.Cast(
672+
translateExpression(expression.getTarget()),
673+
analysis.getType(expression));
674+
}
675+
666676
private io.trino.sql.ir.Expression translate(DereferenceExpression expression)
667677
{
668678
if (analysis.isColumnReference(expression)) {
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package io.trino.operator.scalar;
15+
16+
import io.trino.spi.type.DoubleType;
17+
import io.trino.sql.query.QueryAssertions;
18+
import org.junit.jupiter.api.AfterAll;
19+
import org.junit.jupiter.api.BeforeAll;
20+
import org.junit.jupiter.api.Test;
21+
import org.junit.jupiter.api.TestInstance;
22+
import org.junit.jupiter.api.parallel.Execution;
23+
24+
import static org.assertj.core.api.Assertions.assertThat;
25+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
26+
import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS;
27+
import static org.junit.jupiter.api.parallel.ExecutionMode.CONCURRENT;
28+
29+
@TestInstance(PER_CLASS)
30+
@Execution(CONCURRENT)
31+
public class TestStaticMethodCall
32+
{
33+
private QueryAssertions assertions;
34+
35+
@BeforeAll
36+
public void init()
37+
{
38+
assertions = new QueryAssertions();
39+
}
40+
41+
@AfterAll
42+
public void teardown()
43+
{
44+
assertions.close();
45+
assertions = null;
46+
}
47+
48+
@Test
49+
void testPostgreSqlStyleCast()
50+
{
51+
assertThat(assertions.expression("1::double"))
52+
.hasType(DoubleType.DOUBLE)
53+
.isEqualTo(1.0);
54+
55+
assertThat(assertions.expression("(a + b)::double")
56+
.binding("a", "1")
57+
.binding("b", "2"))
58+
.hasType(DoubleType.DOUBLE)
59+
.isEqualTo(3.0);
60+
}
61+
62+
@Test
63+
void testCall()
64+
{
65+
assertThatThrownBy(() -> assertions.expression("1::double(2)").evaluate())
66+
.hasMessage("line 1:13: Static method calls are not supported");
67+
68+
assertThatThrownBy(() -> assertions.expression("1::foo").evaluate())
69+
.hasMessage("line 1:13: Static method calls are not supported");
70+
71+
assertThatThrownBy(() -> assertions.expression("integer::foo").evaluate())
72+
.hasMessage("line 1:19: Static method calls are not supported");
73+
74+
assertThatThrownBy(() -> assertions.expression("integer::foo(1, 2)").evaluate())
75+
.hasMessage("line 1:19: Static method calls are not supported");
76+
}
77+
}

core/trino-parser/src/main/java/io/trino/sql/ExpressionFormatter.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@
9191
import io.trino.sql.tree.SimpleGroupBy;
9292
import io.trino.sql.tree.SkipTo;
9393
import io.trino.sql.tree.SortItem;
94+
import io.trino.sql.tree.StaticMethodCall;
9495
import io.trino.sql.tree.StringLiteral;
9596
import io.trino.sql.tree.SubqueryExpression;
9697
import io.trino.sql.tree.SubscriptExpression;
@@ -467,6 +468,24 @@ protected String visitFunctionCall(FunctionCall node, Void context)
467468
return builder.toString();
468469
}
469470

471+
@Override
472+
protected String visitStaticMethodCall(StaticMethodCall node, Void context)
473+
{
474+
StringBuilder builder = new StringBuilder();
475+
476+
builder.append(process(node.getTarget(), context))
477+
.append("::")
478+
.append(process(node.getMethod(), context));
479+
480+
if (!node.getArguments().isEmpty()) {
481+
builder.append('(')
482+
.append(joinExpressions(node.getArguments()))
483+
.append(')');
484+
}
485+
486+
return builder.toString();
487+
}
488+
470489
@Override
471490
protected String visitWindowOperation(WindowOperation node, Void context)
472491
{

core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,7 @@
276276
import io.trino.sql.tree.SortItem;
277277
import io.trino.sql.tree.StartTransaction;
278278
import io.trino.sql.tree.Statement;
279+
import io.trino.sql.tree.StaticMethodCall;
279280
import io.trino.sql.tree.StringLiteral;
280281
import io.trino.sql.tree.SubqueryExpression;
281282
import io.trino.sql.tree.SubscriptExpression;
@@ -3081,6 +3082,16 @@ else if (processingMode.FINAL() != null) {
30813082
arguments);
30823083
}
30833084

3085+
@Override
3086+
public Node visitStaticMethodCall(SqlBaseParser.StaticMethodCallContext context)
3087+
{
3088+
return new StaticMethodCall(
3089+
getLocation(context.DOUBLE_COLON()),
3090+
(Expression) visit(context.primaryExpression()),
3091+
(Identifier) visit(context.identifier()),
3092+
visit(context.expression(), Expression.class));
3093+
}
3094+
30843095
@Override
30853096
public Node visitMeasure(SqlBaseParser.MeasureContext context)
30863097
{

core/trino-parser/src/main/java/io/trino/sql/tree/AstVisitor.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,11 @@ protected R visitFunctionCall(FunctionCall node, C context)
322322
return visitExpression(node, context);
323323
}
324324

325+
protected R visitStaticMethodCall(StaticMethodCall node, C context)
326+
{
327+
return visitExpression(node, context);
328+
}
329+
325330
protected R visitProcessingMode(ProcessingMode node, C context)
326331
{
327332
return visitNode(node, context);

0 commit comments

Comments
 (0)