Skip to content

Commit f776252

Browse files
committed
Improve performance when parsing invalid queries
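Use BailErrorStrategy as recommended by the ANTLR team. The first (SLL) pass now aborts at the first syntax error with a ParseCancellationException instead of reporting it; the parser's error listener and reporting error strategy are installed only for the LL fallback pass.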
1 parent b0a3bd9 commit f776252

File tree

5 files changed

+49
-35
lines changed

5 files changed

+49
-35
lines changed

core/trino-parser/src/main/java/io/trino/sql/jsonpath/PathParser.java

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,19 @@
1919
import io.trino.sql.jsonpath.tree.PathNode;
2020
import io.trino.sql.parser.ParsingException;
2121
import io.trino.sql.tree.NodeLocation;
22+
import org.antlr.v4.runtime.BailErrorStrategy;
2223
import org.antlr.v4.runtime.BaseErrorListener;
2324
import org.antlr.v4.runtime.CharStreams;
2425
import org.antlr.v4.runtime.CommonToken;
2526
import org.antlr.v4.runtime.CommonTokenStream;
27+
import org.antlr.v4.runtime.DefaultErrorStrategy;
2628
import org.antlr.v4.runtime.ParserRuleContext;
2729
import org.antlr.v4.runtime.RecognitionException;
2830
import org.antlr.v4.runtime.Recognizer;
2931
import org.antlr.v4.runtime.Token;
3032
import org.antlr.v4.runtime.atn.PredictionMode;
3133
import org.antlr.v4.runtime.misc.Pair;
34+
import org.antlr.v4.runtime.misc.ParseCancellationException;
3235
import org.antlr.v4.runtime.tree.TerminalNode;
3336

3437
import java.util.Arrays;
@@ -95,20 +98,20 @@ public PathNode parseJsonPath(String path)
9598
lexer.addErrorListener(errorListener);
9699

97100
parser.removeErrorListeners();
98-
parser.addErrorListener(errorListener);
99101

100102
ParserRuleContext tree;
101103
try {
102104
// first, try parsing with potentially faster SLL mode
103105
parser.getInterpreter().setPredictionMode(PredictionMode.SLL);
106+
parser.setErrorHandler(new BailErrorStrategy());
104107
tree = parser.path();
105108
}
106-
catch (ParsingException ex) {
109+
catch (ParseCancellationException e) {
107110
// if we fail, parse with LL mode
108-
tokenStream.seek(0); // rewind input stream
109111
parser.reset();
110-
111112
parser.getInterpreter().setPredictionMode(PredictionMode.LL);
113+
parser.setErrorHandler(new DefaultErrorStrategy());
114+
parser.addErrorListener(errorListener);
112115
tree = parser.path();
113116
}
114117

core/trino-parser/src/main/java/io/trino/sql/parser/SqlParser.java

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import io.trino.sql.tree.RowPattern;
2626
import io.trino.sql.tree.Statement;
2727
import org.antlr.v4.runtime.ANTLRErrorListener;
28+
import org.antlr.v4.runtime.BailErrorStrategy;
2829
import org.antlr.v4.runtime.BaseErrorListener;
2930
import org.antlr.v4.runtime.CharStreams;
3031
import org.antlr.v4.runtime.CommonToken;
@@ -39,6 +40,7 @@
3940
import org.antlr.v4.runtime.atn.PredictionMode;
4041
import org.antlr.v4.runtime.misc.Interval;
4142
import org.antlr.v4.runtime.misc.Pair;
43+
import org.antlr.v4.runtime.misc.ParseCancellationException;
4244
import org.antlr.v4.runtime.tree.TerminalNode;
4345

4446
import java.util.Arrays;
@@ -144,42 +146,27 @@ private Node invokeParser(String name, String sql, Optional<NodeLocation> locati
144146
SqlBaseParser parser = new SqlBaseParser(tokenStream);
145147
initializer.accept(lexer, parser);
146148

147-
// Override the default error strategy to not attempt inserting or deleting a token.
148-
// Otherwise, it messes up error reporting
149-
parser.setErrorHandler(new DefaultErrorStrategy()
150-
{
151-
@Override
152-
public Token recoverInline(Parser recognizer)
153-
throws RecognitionException
154-
{
155-
if (nextTokensContext == null) {
156-
throw new InputMismatchException(recognizer);
157-
}
158-
throw new InputMismatchException(recognizer, nextTokensState, nextTokensContext);
159-
}
160-
});
161-
162149
parser.addParseListener(new PostProcessor(Arrays.asList(parser.getRuleNames()), parser));
163150

164151
lexer.removeErrorListeners();
165152
lexer.addErrorListener(LEXER_ERROR_LISTENER);
166153

167154
parser.removeErrorListeners();
168-
parser.addErrorListener(PARSER_ERROR_HANDLER);
169155

170156
ParserRuleContext tree;
171157
try {
172158
try {
173159
// first, try parsing with potentially faster SLL mode
174160
parser.getInterpreter().setPredictionMode(PredictionMode.SLL);
161+
parser.setErrorHandler(new BailErrorStrategy());
175162
tree = parseFunction.apply(parser);
176163
}
177-
catch (ParsingException ex) {
164+
catch (ParseCancellationException e) {
178165
// if we fail, parse with LL mode
179-
tokenStream.seek(0); // rewind input stream
180166
parser.reset();
181-
182167
parser.getInterpreter().setPredictionMode(PredictionMode.LL);
168+
parser.setErrorHandler(new NonRecoveringErrorStrategy());
169+
parser.addErrorListener(PARSER_ERROR_HANDLER);
183170
tree = parseFunction.apply(parser);
184171
}
185172
}
@@ -203,6 +190,21 @@ public Token recoverInline(Parser recognizer)
203190
}
204191
}
205192

193+
// Override the default error strategy to not attempt inserting or deleting a token.
194+
// Otherwise, it messes up error reporting.
195+
private static final class NonRecoveringErrorStrategy
196+
extends DefaultErrorStrategy
197+
{
198+
@Override
199+
public Token recoverInline(Parser recognizer)
200+
{
201+
if (nextTokensContext == null) {
202+
throw new InputMismatchException(recognizer);
203+
}
204+
throw new InputMismatchException(recognizer, nextTokensState, nextTokensContext);
205+
}
206+
}
207+
206208
private static class PostProcessor
207209
extends SqlBaseBaseListener
208210
{

core/trino-parser/src/main/java/io/trino/type/TypeCalculation.java

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,16 @@
2727
import io.trino.sql.parser.ParsingException;
2828
import io.trino.sql.tree.NodeLocation;
2929
import org.antlr.v4.runtime.ANTLRErrorListener;
30+
import org.antlr.v4.runtime.BailErrorStrategy;
3031
import org.antlr.v4.runtime.BaseErrorListener;
3132
import org.antlr.v4.runtime.CharStreams;
3233
import org.antlr.v4.runtime.CommonTokenStream;
34+
import org.antlr.v4.runtime.DefaultErrorStrategy;
3335
import org.antlr.v4.runtime.ParserRuleContext;
3436
import org.antlr.v4.runtime.RecognitionException;
3537
import org.antlr.v4.runtime.Recognizer;
3638
import org.antlr.v4.runtime.atn.PredictionMode;
39+
import org.antlr.v4.runtime.misc.ParseCancellationException;
3740

3841
import java.math.BigInteger;
3942
import java.util.Map;
@@ -85,20 +88,20 @@ private static ParserRuleContext parseTypeCalculation(String calculation)
8588
lexer.addErrorListener(ERROR_LISTENER);
8689

8790
parser.removeErrorListeners();
88-
parser.addErrorListener(ERROR_LISTENER);
8991

9092
ParserRuleContext tree;
9193
try {
9294
// first, try parsing with potentially faster SLL mode
9395
parser.getInterpreter().setPredictionMode(PredictionMode.SLL);
96+
parser.setErrorHandler(new BailErrorStrategy());
9497
tree = parser.typeCalculation();
9598
}
96-
catch (ParsingException ex) {
99+
catch (ParseCancellationException e) {
97100
// if we fail, parse with LL mode
98-
tokenStream.seek(0); // rewind input stream
99101
parser.reset();
100-
101102
parser.getInterpreter().setPredictionMode(PredictionMode.LL);
103+
parser.setErrorHandler(new DefaultErrorStrategy());
104+
parser.addErrorListener(ERROR_LISTENER);
102105
tree = parser.typeCalculation();
103106
}
104107
return tree;

plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ExpressionMappingParser.java

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,16 @@
1515

1616
import com.google.common.collect.ImmutableMap;
1717
import org.antlr.v4.runtime.ANTLRErrorListener;
18+
import org.antlr.v4.runtime.BailErrorStrategy;
1819
import org.antlr.v4.runtime.BaseErrorListener;
1920
import org.antlr.v4.runtime.CharStreams;
2021
import org.antlr.v4.runtime.CommonTokenStream;
22+
import org.antlr.v4.runtime.DefaultErrorStrategy;
2123
import org.antlr.v4.runtime.ParserRuleContext;
2224
import org.antlr.v4.runtime.RecognitionException;
2325
import org.antlr.v4.runtime.Recognizer;
2426
import org.antlr.v4.runtime.atn.PredictionMode;
27+
import org.antlr.v4.runtime.misc.ParseCancellationException;
2528

2629
import java.util.Map;
2730
import java.util.Set;
@@ -69,20 +72,20 @@ public Object invokeParser(String input, Function<ConnectorExpressionPatternPars
6972
lexer.addErrorListener(ERROR_LISTENER);
7073

7174
parser.removeErrorListeners();
72-
parser.addErrorListener(ERROR_LISTENER);
7375

7476
ParserRuleContext tree;
7577
try {
7678
// first, try parsing with potentially faster SLL mode
7779
parser.getInterpreter().setPredictionMode(PredictionMode.SLL);
80+
parser.setErrorHandler(new BailErrorStrategy());
7881
tree = parseFunction.apply(parser);
7982
}
80-
catch (IllegalArgumentException ex) {
83+
catch (ParseCancellationException e) {
8184
// if we fail, parse with LL mode
82-
tokenStream.seek(0); // rewind input stream
8385
parser.reset();
84-
8586
parser.getInterpreter().setPredictionMode(PredictionMode.LL);
87+
parser.setErrorHandler(new DefaultErrorStrategy());
88+
parser.addErrorListener(ERROR_LISTENER);
8689
tree = parseFunction.apply(parser);
8790
}
8891
return new ExpressionPatternBuilder(typeClasses).visit(tree);

plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/expression/SparkExpressionParser.java

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,16 @@
1515

1616
import com.google.common.annotations.VisibleForTesting;
1717
import org.antlr.v4.runtime.ANTLRErrorListener;
18+
import org.antlr.v4.runtime.BailErrorStrategy;
1819
import org.antlr.v4.runtime.BaseErrorListener;
1920
import org.antlr.v4.runtime.CharStreams;
2021
import org.antlr.v4.runtime.CommonTokenStream;
22+
import org.antlr.v4.runtime.DefaultErrorStrategy;
2123
import org.antlr.v4.runtime.ParserRuleContext;
2224
import org.antlr.v4.runtime.RecognitionException;
2325
import org.antlr.v4.runtime.Recognizer;
2426
import org.antlr.v4.runtime.atn.PredictionMode;
27+
import org.antlr.v4.runtime.misc.ParseCancellationException;
2528

2629
import java.util.function.Function;
2730

@@ -68,20 +71,20 @@ private static Object invokeParser(String input, Function<SparkExpressionBasePar
6871
lexer.addErrorListener(ERROR_LISTENER);
6972

7073
parser.removeErrorListeners();
71-
parser.addErrorListener(ERROR_LISTENER);
7274

7375
ParserRuleContext tree;
7476
try {
7577
// first, try parsing with potentially faster SLL mode
7678
parser.getInterpreter().setPredictionMode(PredictionMode.SLL);
79+
parser.setErrorHandler(new BailErrorStrategy());
7780
tree = parseFunction.apply(parser);
7881
}
79-
catch (ParsingException ex) {
82+
catch (ParseCancellationException e) {
8083
// if we fail, parse with LL mode
81-
tokenStream.seek(0); // rewind input stream
8284
parser.reset();
83-
8485
parser.getInterpreter().setPredictionMode(PredictionMode.LL);
86+
parser.setErrorHandler(new DefaultErrorStrategy());
87+
parser.addErrorListener(ERROR_LISTENER);
8588
tree = parseFunction.apply(parser);
8689
}
8790
return new SparkExpressionBuilder().visit(tree);

0 commit comments

Comments
 (0)