Skip to content

Commit 6714cef

Browse files
authored
Merge pull request swiftlang#33254 from DougGregor/stmt-checker-astscope
[Statement checker] Leverage ASTScope to eliminate recursive walk
2 parents 2680b0c + cc17aef commit 6714cef

File tree

8 files changed

+465
-227
lines changed

8 files changed

+465
-227
lines changed

include/swift/AST/ASTScope.h

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ class ASTScopeImpl {
195195

196196
/// Get ride of descendants and remove them from scopedNodes so the scopes
197197
/// can be recreated. Needed because typechecking inserts a return statment
198-
/// into intiailizers.
198+
/// into initializers.
199199
void disownDescendants(ScopeCreator &);
200200

201201
public: // for addReusedBodyScopes
@@ -415,10 +415,17 @@ class ASTScopeImpl {
415415
unqualifiedLookup(SourceFile *, DeclNameRef, SourceLoc,
416416
const DeclContext *startingContext, DeclConsumer);
417417

418+
/// Entry point into ASTScopeImpl-land for labeled statement lookups.
419+
static llvm::SmallVector<LabeledStmt *, 4>
420+
lookupLabeledStmts(SourceFile *sourceFile, SourceLoc loc);
421+
418422
static Optional<bool>
419423
computeIsCascadingUse(ArrayRef<const ASTScopeImpl *> history,
420424
Optional<bool> initialIsCascadingUse);
421425

426+
static std::pair<CaseStmt *, CaseStmt *>
427+
lookupFallthroughSourceAndDest(SourceFile *sourceFile, SourceLoc loc);
428+
422429
#pragma mark - - lookup- starting point
423430
private:
424431
static const ASTScopeImpl *findStartingScopeForLookup(SourceFile *,
@@ -538,6 +545,12 @@ class ASTScopeImpl {
538545

539546
NullablePtr<const ASTScopeImpl>
540547
ancestorWithDeclSatisfying(function_ref<bool(const Decl *)> predicate) const;
548+
549+
/// Whether this scope terminates lookup of labeled statements in the
550+
/// children below it, because one cannot perform a "break" or a "continue"
551+
/// in a child that goes outside of this scope.
552+
virtual bool isLabeledStmtLookupTerminator() const;
553+
541554
}; // end of ASTScopeImpl
542555

543556
#pragma mark - specific scope classes
@@ -1357,6 +1370,9 @@ class ConditionalClauseScope final : public ASTScopeImpl {
13571370
private:
13581371
ArrayRef<StmtConditionElement> getCond() const;
13591372
const StmtConditionElement &getStmtConditionElement() const;
1373+
1374+
protected:
1375+
bool isLabeledStmtLookupTerminator() const override;
13601376
};
13611377

13621378
/// If, while, & guard statements all start with a conditional clause, then some
@@ -1380,6 +1396,7 @@ class ConditionalClausePatternUseScope final : public ASTScopeImpl {
13801396
bool lookupLocalsOrMembers(ArrayRef<const ASTScopeImpl *>,
13811397
DeclConsumer) const override;
13821398
void printSpecifics(llvm::raw_ostream &out) const override;
1399+
bool isLabeledStmtLookupTerminator() const override;
13831400
};
13841401

13851402

@@ -1713,17 +1730,16 @@ class AbstractStmtScope : public ASTScopeImpl {
17131730
virtual Stmt *getStmt() const = 0;
17141731
NullablePtr<Stmt> getStmtIfAny() const override { return getStmt(); }
17151732
NullablePtr<const void> getReferrent() const override;
1733+
1734+
protected:
1735+
bool isLabeledStmtLookupTerminator() const override;
17161736
};
17171737

17181738
class LabeledConditionalStmtScope : public AbstractStmtScope {
17191739
public:
17201740
Stmt *getStmt() const override;
17211741
virtual LabeledConditionalStmt *getLabeledConditionalStmt() const = 0;
17221742

1723-
/// If a condition is present, create the martuska.
1724-
/// Return the lookupParent for the use scope.
1725-
ASTScopeImpl *createCondScopes();
1726-
17271743
protected:
17281744
/// Return the lookupParent required to search these.
17291745
ASTScopeImpl *createNestedConditionalClauseScopes(ScopeCreator &,
@@ -1807,6 +1823,7 @@ class LookupParentDiversionScope final : public ASTScopeImpl {
18071823
NullablePtr<const ASTScopeImpl> getLookupParent() const override {
18081824
return lookupParent;
18091825
}
1826+
bool isLabeledStmtLookupTerminator() const override;
18101827
};
18111828

18121829
class RepeatWhileScope final : public AbstractStmtScope {
@@ -1826,6 +1843,23 @@ class RepeatWhileScope final : public AbstractStmtScope {
18261843
Stmt *getStmt() const override { return stmt; }
18271844
};
18281845

1846+
class DoStmtScope final : public AbstractStmtScope {
1847+
public:
1848+
DoStmt *const stmt;
1849+
DoStmtScope(DoStmt *e) : stmt(e) {}
1850+
virtual ~DoStmtScope() {}
1851+
1852+
protected:
1853+
ASTScopeImpl *expandSpecifically(ScopeCreator &scopeCreator) override;
1854+
1855+
private:
1856+
void expandAScopeThatDoesNotCreateANewInsertionPoint(ScopeCreator &);
1857+
1858+
public:
1859+
std::string getClassName() const override;
1860+
Stmt *getStmt() const override { return stmt; }
1861+
};
1862+
18291863
class DoCatchStmtScope final : public AbstractStmtScope {
18301864
public:
18311865
DoCatchStmt *const stmt;
@@ -1897,6 +1931,7 @@ class ForEachPatternScope final : public ASTScopeImpl {
18971931
protected:
18981932
bool lookupLocalsOrMembers(ArrayRef<const ASTScopeImpl *>,
18991933
DeclConsumer) const override;
1934+
bool isLabeledStmtLookupTerminator() const override;
19001935
};
19011936

19021937
class CaseStmtScope final : public AbstractStmtScope {

include/swift/AST/AnyFunctionRef.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,20 @@ class AnyFunctionRef {
5353
}
5454
}
5555

56+
/// Construct an AnyFunctionRef from a decl context that might be
57+
/// some sort of function.
58+
static Optional<AnyFunctionRef> fromDeclContext(DeclContext *dc) {
59+
if (auto fn = dyn_cast<AbstractFunctionDecl>(dc)) {
60+
return AnyFunctionRef(fn);
61+
}
62+
63+
if (auto ace = dyn_cast<AbstractClosureExpr>(dc)) {
64+
return AnyFunctionRef(ace);
65+
}
66+
67+
return None;
68+
}
69+
5670
CaptureInfo getCaptureInfo() const {
5771
if (auto *AFD = TheFunction.dyn_cast<AbstractFunctionDecl *>())
5872
return AFD->getCaptureInfo();

include/swift/AST/NameLookup.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -706,6 +706,31 @@ class ASTScope {
706706
computeIsCascadingUse(ArrayRef<const ast_scope::ASTScopeImpl *> history,
707707
Optional<bool> initialIsCascadingUse);
708708

709+
/// Entry point to record the visible statement labels from the given
710+
/// point.
711+
///
712+
/// This lookup only considers labels that are visible within the current
713+
/// function, so it will not return any labels from lexical scopes that
714+
/// are not reachable via labeled control flow.
715+
///
716+
/// \returns the set of labeled statements visible from the given source
717+
/// location, with the innermost labeled statement first and proceeding
718+
/// to the outermost labeled statement.
719+
static llvm::SmallVector<LabeledStmt *, 4>
720+
lookupLabeledStmts(SourceFile *sourceFile, SourceLoc loc);
721+
722+
/// Look for the directly enclosing case statement and the next case
723+
/// statement, which together act as the source and destination for a
724+
/// 'fallthrough' statement within a switch case.
725+
///
726+
/// \returns a pair (fallthrough source, fallthrough dest). If the location
727+
/// is not within the body of a case statement at all, the fallthrough
728+
/// source will be \c nullptr. If there is a fallthrough source that case is
729+
/// the last one, the fallthrough destination will be \c nullptr. A
730+
/// well-formed 'fallthrough' statement has both a source and destination.
731+
static std::pair<CaseStmt *, CaseStmt *>
732+
lookupFallthroughSourceAndDest(SourceFile *sourceFile, SourceLoc loc);
733+
709734
SWIFT_DEBUG_DUMP;
710735
void print(llvm::raw_ostream &) const;
711736
void dumpOneScopeMapLocation(std::pair<unsigned, unsigned>);

lib/AST/ASTScope.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,16 @@ Optional<bool> ASTScope::computeIsCascadingUse(
5454
return ASTScopeImpl::computeIsCascadingUse(history, initialIsCascadingUse);
5555
}
5656

57+
llvm::SmallVector<LabeledStmt *, 4> ASTScope::lookupLabeledStmts(
58+
SourceFile *sourceFile, SourceLoc loc) {
59+
return ASTScopeImpl::lookupLabeledStmts(sourceFile, loc);
60+
}
61+
62+
std::pair<CaseStmt *, CaseStmt *> ASTScope::lookupFallthroughSourceAndDest(
63+
SourceFile *sourceFile, SourceLoc loc) {
64+
return ASTScopeImpl::lookupFallthroughSourceAndDest(sourceFile, loc);
65+
}
66+
5767
#if SWIFT_COMPILER_IS_MSVC
5868
#pragma warning(push)
5969
#pragma warning(disable : 4996)
@@ -237,6 +247,7 @@ DEFINE_GET_CLASS_NAME(WhileStmtScope)
237247
DEFINE_GET_CLASS_NAME(GuardStmtScope)
238248
DEFINE_GET_CLASS_NAME(LookupParentDiversionScope)
239249
DEFINE_GET_CLASS_NAME(RepeatWhileScope)
250+
DEFINE_GET_CLASS_NAME(DoStmtScope)
240251
DEFINE_GET_CLASS_NAME(DoCatchStmtScope)
241252
DEFINE_GET_CLASS_NAME(SwitchStmtScope)
242253
DEFINE_GET_CLASS_NAME(ForEachStmtScope)

lib/AST/ASTScopeCreation.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -866,6 +866,7 @@ class NodeAdder
866866
VISIT_AND_CREATE(IfStmt, IfStmtScope)
867867
VISIT_AND_CREATE(WhileStmt, WhileStmtScope)
868868
VISIT_AND_CREATE(RepeatWhileStmt, RepeatWhileScope)
869+
VISIT_AND_CREATE(DoStmt, DoStmtScope)
869870
VISIT_AND_CREATE(DoCatchStmt, DoCatchStmtScope)
870871
VISIT_AND_CREATE(SwitchStmt, SwitchStmtScope)
871872
VISIT_AND_CREATE(ForEachStmt, ForEachStmtScope)
@@ -908,11 +909,6 @@ class NodeAdder
908909
ScopeCreator &scopeCreator) {
909910
return scopeCreator.ifUniqueConstructExpandAndInsert<GuardStmtScope>(p, e);
910911
}
911-
NullablePtr<ASTScopeImpl> visitDoStmt(DoStmt *ds, ASTScopeImpl *p,
912-
ScopeCreator &scopeCreator) {
913-
scopeCreator.addToScopeTreeAndReturnInsertionPoint(ds->getBody(), p);
914-
return p; // Don't put subsequent decls inside the "do"
915-
}
916912
NullablePtr<ASTScopeImpl> visitTopLevelCodeDecl(TopLevelCodeDecl *d,
917913
ASTScopeImpl *p,
918914
ScopeCreator &scopeCreator) {
@@ -1204,6 +1200,7 @@ NO_NEW_INSERTION_POINT(CaptureListScope)
12041200
NO_NEW_INSERTION_POINT(CaseStmtScope)
12051201
NO_NEW_INSERTION_POINT(ClosureBodyScope)
12061202
NO_NEW_INSERTION_POINT(DefaultArgumentInitializerScope)
1203+
NO_NEW_INSERTION_POINT(DoStmtScope)
12071204
NO_NEW_INSERTION_POINT(DoCatchStmtScope)
12081205
NO_NEW_INSERTION_POINT(ForEachPatternScope)
12091206
NO_NEW_INSERTION_POINT(ForEachStmtScope)
@@ -1470,6 +1467,11 @@ void RepeatWhileScope::expandAScopeThatDoesNotCreateANewInsertionPoint(
14701467
scopeCreator.addToScopeTree(stmt->getCond(), this);
14711468
}
14721469

1470+
void DoStmtScope::expandAScopeThatDoesNotCreateANewInsertionPoint(
1471+
ScopeCreator &scopeCreator) {
1472+
scopeCreator.addToScopeTree(stmt->getBody(), this);
1473+
}
1474+
14731475
void DoCatchStmtScope::expandAScopeThatDoesNotCreateANewInsertionPoint(
14741476
ScopeCreator &scopeCreator) {
14751477
scopeCreator.addToScopeTree(stmt->getBody(), this);

lib/AST/ASTScopeLookup.cpp

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -852,3 +852,108 @@ bool isLocWithinAnInactiveClause(const SourceLoc loc, SourceFile *SF) {
852852
SF->walk(tester);
853853
return tester.wasFoundWithinInactiveClause;
854854
}
855+
856+
#pragma mark isLabeledStmtLookupTerminator implementations
857+
bool ASTScopeImpl::isLabeledStmtLookupTerminator() const {
858+
return true;
859+
}
860+
861+
bool LookupParentDiversionScope::isLabeledStmtLookupTerminator() const {
862+
return false;
863+
}
864+
865+
bool ConditionalClauseScope::isLabeledStmtLookupTerminator() const {
866+
return false;
867+
}
868+
869+
bool ConditionalClausePatternUseScope::isLabeledStmtLookupTerminator() const {
870+
return false;
871+
}
872+
873+
bool AbstractStmtScope::isLabeledStmtLookupTerminator() const {
874+
return false;
875+
}
876+
877+
bool ForEachPatternScope::isLabeledStmtLookupTerminator() const {
878+
return false;
879+
}
880+
881+
llvm::SmallVector<LabeledStmt *, 4>
882+
ASTScopeImpl::lookupLabeledStmts(SourceFile *sourceFile, SourceLoc loc) {
883+
// Find the innermost scope from which to start our search.
884+
auto *const fileScope = sourceFile->getScope().impl;
885+
const auto *innermost = fileScope->findInnermostEnclosingScope(loc, nullptr);
886+
ASTScopeAssert(innermost->getWasExpanded(),
887+
"If looking in a scope, it must have been expanded.");
888+
889+
llvm::SmallVector<LabeledStmt *, 4> labeledStmts;
890+
for (auto scope = innermost; scope && !scope->isLabeledStmtLookupTerminator();
891+
scope = scope->getParent().getPtrOrNull()) {
892+
// If we have a labeled statement, record it.
893+
auto stmt = scope->getStmtIfAny();
894+
if (!stmt) continue;
895+
896+
auto labeledStmt = dyn_cast<LabeledStmt>(stmt.get());
897+
if (!labeledStmt) continue;
898+
899+
// Skip guard statements; they aren't actually targets for break or
900+
// continue.
901+
if (isa<GuardStmt>(labeledStmt)) continue;
902+
903+
labeledStmts.push_back(labeledStmt);
904+
}
905+
906+
return labeledStmts;
907+
}
908+
909+
std::pair<CaseStmt *, CaseStmt *> ASTScopeImpl::lookupFallthroughSourceAndDest(
910+
SourceFile *sourceFile, SourceLoc loc) {
911+
// Find the innermost scope from which to start our search.
912+
auto *const fileScope = sourceFile->getScope().impl;
913+
const auto *innermost = fileScope->findInnermostEnclosingScope(loc, nullptr);
914+
ASTScopeAssert(innermost->getWasExpanded(),
915+
"If looking in a scope, it must have been expanded.");
916+
917+
// Look for the enclosing case statement and its 'switch' statement.
918+
CaseStmt *fallthroughSource = nullptr;
919+
SwitchStmt *switchStmt = nullptr;
920+
for (auto scope = innermost; scope && !scope->isLabeledStmtLookupTerminator();
921+
scope = scope->getParent().getPtrOrNull()) {
922+
// If we have a case statement, record it.
923+
auto stmt = scope->getStmtIfAny();
924+
if (!stmt) continue;
925+
926+
// If we've found the first case statement of a switch, record it as the
927+
// fallthrough source. do-catch statements don't support fallthrough.
928+
if (auto caseStmt = dyn_cast<CaseStmt>(stmt.get())) {
929+
if (!fallthroughSource &&
930+
caseStmt->getParentKind() == CaseParentKind::Switch)
931+
fallthroughSource = caseStmt;
932+
933+
continue;
934+
}
935+
936+
// If we've found the first switch statement, record it and we're done.
937+
switchStmt = dyn_cast<SwitchStmt>(stmt.get());
938+
if (switchStmt)
939+
break;
940+
}
941+
942+
// If we don't have both a fallthrough source and a switch statement
943+
// enclosing it, the 'fallthrough' statement is ill-formed.
944+
if (!fallthroughSource || !switchStmt)
945+
return { nullptr, nullptr };
946+
947+
// Find this case in the list of cases for the switch. If we don't find it
948+
// here, it means that the case isn't directly nested inside the switch, so
949+
// the case and fallthrough are both ill-formed.
950+
auto caseIter = llvm::find(switchStmt->getCases(), fallthroughSource);
951+
if (caseIter == switchStmt->getCases().end())
952+
return { nullptr, nullptr };
953+
954+
// Move along to the next case. This is the fallthrough destination.
955+
++caseIter;
956+
auto fallthroughDest = caseIter == switchStmt->getCases().end() ? nullptr
957+
: *caseIter;
958+
return { fallthroughSource, fallthroughDest };
959+
}

0 commit comments

Comments
 (0)