Skip to content

Commit 9ff0e64

Browse files
committed
[AST] Ensure the ASTWalker always stops
Fix up a few cases where we weren't propagating the stop action correctly, sprinkling in some `LLVM_NODISCARD` to help avoid the issue in the future.
1 parent ee773d9 commit 9ff0e64

File tree

1 file changed

+43
-14
lines changed

1 file changed

+43
-14
lines changed

lib/AST/ASTWalker.cpp

Lines changed: 43 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -94,31 +94,37 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
9494
}
9595
};
9696

97+
LLVM_NODISCARD
9798
Expr *visit(Expr *E) {
9899
SetParentRAII SetParent(Walker, E);
99100
return inherited::visit(E);
100101
}
101102

103+
LLVM_NODISCARD
102104
Stmt *visit(Stmt *S) {
103105
SetParentRAII SetParent(Walker, S);
104106
return inherited::visit(S);
105107
}
106-
108+
109+
LLVM_NODISCARD
107110
Pattern *visit(Pattern *P) {
108111
SetParentRAII SetParent(Walker, P);
109112
return inherited::visit(P);
110113
}
111114

115+
LLVM_NODISCARD
112116
bool visit(Decl *D) {
113117
SetParentRAII SetParent(Walker, D);
114118
return inherited::visit(D);
115119
}
116-
120+
121+
LLVM_NODISCARD
117122
bool visit(TypeRepr *T) {
118123
SetParentRAII SetParent(Walker, T);
119124
return inherited::visit(T);
120125
}
121-
126+
127+
LLVM_NODISCARD
122128
bool visit(ParameterList *PL) {
123129
return inherited::visit(PL);
124130
}
@@ -127,17 +133,20 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
127133
// Decls
128134
//===--------------------------------------------------------------------===//
129135

136+
LLVM_NODISCARD
130137
bool visitGenericParamListIfNeeded(GenericContext *GC) {
131138
// Must check this first in case extensions have not been bound yet
132139
if (Walker.shouldWalkIntoGenericParams()) {
133140
if (auto *params = GC->getParsedGenericParams()) {
134-
doIt(params);
141+
if (doIt(params))
142+
return true;
135143
}
136144
return true;
137145
}
138146
return false;
139147
}
140148

149+
LLVM_NODISCARD
141150
bool visitTrailingRequirements(GenericContext *GC) {
142151
if (const auto Where = GC->getTrailingWhereClause()) {
143152
for (auto &Req: Where->getRequirements())
@@ -368,7 +377,9 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
368377
bool visitSubscriptDecl(SubscriptDecl *SD) {
369378
bool WalkGenerics = visitGenericParamListIfNeeded(SD);
370379

371-
visit(SD->getIndices());
380+
if (visit(SD->getIndices()))
381+
return true;
382+
372383
if (auto *const TyR = SD->getElementTypeRepr())
373384
if (doIt(TyR))
374385
return true;
@@ -399,9 +410,12 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
399410
// accessor generics are visited from the storage decl
400411
!isa<AccessorDecl>(AFD) && visitGenericParamListIfNeeded(AFD);
401412

402-
if (auto *PD = AFD->getImplicitSelfDecl(/*createIfNeeded=*/false))
403-
visit(PD);
404-
visit(AFD->getParameters());
413+
if (auto *PD = AFD->getImplicitSelfDecl(/*createIfNeeded=*/false)) {
414+
if (visit(PD))
415+
return true;
416+
}
417+
if (visit(AFD->getParameters()))
418+
return true;
405419

406420
if (auto *FD = dyn_cast<FuncDecl>(AFD)) {
407421
if (!isa<AccessorDecl>(FD))
@@ -436,7 +450,8 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
436450

437451
bool visitEnumElementDecl(EnumElementDecl *ED) {
438452
if (auto *PL = ED->getParameterList()) {
439-
visit(PL);
453+
if (visit(PL))
454+
return true;
440455
}
441456

442457
if (auto *rawLiteralExpr = ED->getRawValueUnchecked()) {
@@ -857,7 +872,8 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
857872
}
858873

859874
Expr *visitClosureExpr(ClosureExpr *expr) {
860-
visit(expr->getParameters());
875+
if (visit(expr->getParameters()))
876+
return nullptr;
861877

862878
if (expr->hasExplicitResultType()) {
863879
if (doIt(expr->getExplicitResultTypeRepr()))
@@ -1249,6 +1265,7 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
12491265
template <typename T>
12501266
using PostWalkResult = ASTWalker::PostWalkResult<T>;
12511267

1268+
LLVM_NODISCARD
12521269
bool traverse(PreWalkAction Pre, llvm::function_ref<bool(void)> VisitChildren,
12531270
llvm::function_ref<PostWalkAction(void)> WalkPost) {
12541271
switch (Pre.Action) {
@@ -1271,6 +1288,7 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
12711288
}
12721289

12731290
template <typename T>
1291+
LLVM_NODISCARD
12741292
T *traverse(PreWalkResult<T *> Pre,
12751293
llvm::function_ref<T *(T *)> VisitChildren,
12761294
llvm::function_ref<PostWalkResult<T *>(T *)> WalkPost) {
@@ -1299,6 +1317,7 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
12991317
llvm_unreachable("Unhandled case in switch!");
13001318
}
13011319

1320+
LLVM_NODISCARD
13021321
bool visitParameterList(ParameterList *PL) {
13031322
return traverse(
13041323
Walker.walkToParameterListPre(PL),
@@ -1316,13 +1335,15 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
13161335
public:
13171336
Traversal(ASTWalker &walker) : Walker(walker) {}
13181337

1338+
LLVM_NODISCARD
13191339
Expr *doIt(Expr *E) {
13201340
return traverse<Expr>(
13211341
Walker.walkToExprPre(E),
13221342
[&](Expr *E) { return visit(E); },
13231343
[&](Expr *E) { return Walker.walkToExprPost(E); });
13241344
}
13251345

1346+
LLVM_NODISCARD
13261347
Stmt *doIt(Stmt *S) {
13271348
return traverse<Stmt>(
13281349
Walker.walkToStmtPre(S),
@@ -1350,6 +1371,7 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
13501371
}
13511372

13521373
/// Returns true on failure.
1374+
LLVM_NODISCARD
13531375
bool doIt(Decl *D) {
13541376
if (shouldSkip(D))
13551377
return false;
@@ -1359,14 +1381,16 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
13591381
[&]() { return visit(D); },
13601382
[&]() { return Walker.walkToDeclPost(D); });
13611383
}
1362-
1384+
1385+
LLVM_NODISCARD
13631386
Pattern *doIt(Pattern *P) {
13641387
return traverse<Pattern>(
13651388
Walker.walkToPatternPre(P),
13661389
[&](Pattern *P) { return visit(P); },
13671390
[&](Pattern *P) { return Walker.walkToPatternPost(P); });
13681391
}
13691392

1393+
LLVM_NODISCARD
13701394
bool doIt(const StmtCondition &C) {
13711395
for (auto &elt : C) {
13721396
switch (elt.getKind()) {
@@ -1398,13 +1422,15 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
13981422
}
13991423

14001424
/// Returns true on failure.
1425+
LLVM_NODISCARD
14011426
bool doIt(TypeRepr *T) {
14021427
return traverse(
14031428
Walker.walkToTypeReprPre(T),
14041429
[&]() { return visit(T); },
14051430
[&]() { return Walker.walkToTypeReprPost(T); });
14061431
}
1407-
1432+
1433+
LLVM_NODISCARD
14081434
bool doIt(RequirementRepr &Req) {
14091435
switch (Req.getKind()) {
14101436
case RequirementReprKind::SameType:
@@ -1423,6 +1449,7 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
14231449
return false;
14241450
}
14251451

1452+
LLVM_NODISCARD
14261453
bool doIt(GenericParamList *GPL) {
14271454
// Visit generic params
14281455
for (auto &P : GPL->getParams()) {
@@ -1439,6 +1466,7 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
14391466
return false;
14401467
}
14411468

1469+
LLVM_NODISCARD
14421470
ArgumentList *visit(ArgumentList *ArgList) {
14431471
for (auto Idx : indices(*ArgList)) {
14441472
auto *E = doIt(ArgList->getExpr(Idx));
@@ -1448,6 +1476,7 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
14481476
return ArgList;
14491477
}
14501478

1479+
LLVM_NODISCARD
14511480
ArgumentList *doIt(ArgumentList *ArgList) {
14521481
return traverse<ArgumentList>(
14531482
Walker.walkToArgumentListPre(ArgList),
@@ -2001,12 +2030,12 @@ Pattern *Pattern::walk(ASTWalker &walker) {
20012030
}
20022031

20032032
TypeRepr *TypeRepr::walk(ASTWalker &walker) {
2004-
Traversal(walker).doIt(this);
2033+
(void)Traversal(walker).doIt(this);
20052034
return this;
20062035
}
20072036

20082037
StmtConditionElement *StmtConditionElement::walk(ASTWalker &walker) {
2009-
Traversal(walker).doIt(*this);
2038+
(void)Traversal(walker).doIt(*this);
20102039
return this;
20112040
}
20122041

0 commit comments

Comments
 (0)