Skip to content

Commit d260fbb

Browse files
committed
ASTScope: More accurate modeling of 'case' statements
In a code snippet like the following, static func ==(a: Foo, b: Foo) -> Bool { switch (a, b) { case (.x(let aa), .x(let bb)) where condition(aa, bb), (.y(let aa), .y(let bb)) where condition(aa, bb): return aa == bb default: return false } } The CaseStmt defines two patterns, both of which bind 'aa' and 'bb'. The first 'aa'/'bb' are in scope inside the first 'where' clause, and the second 'aa'/'bb' are in scope inside the second 'where' clause. Furthermore, the parser creates a "fake" VarDecl for 'aa' and 'bb' to represent the phi node merging the two values along the two control flow paths; these are in scope inside the body. Model this situation by introducing a new CaseLabelItemScope for the 'where' clauses, and a CaseStmtBodyScope for the body.
1 parent fdf1882 commit d260fbb

File tree

5 files changed

+131
-25
lines changed

5 files changed

+131
-25
lines changed

include/swift/AST/ASTScope.h

Lines changed: 87 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -529,7 +529,7 @@ class ASTScopeImpl {
529529
// A local binding is a basically a local variable defined in that very scope
530530
// It is not an instance variable or inherited type.
531531

532-
static bool lookupLocalBindingsInPattern(Pattern *p,
532+
static bool lookupLocalBindingsInPattern(const Pattern *p,
533533
DeclVisibilityKind vis,
534534
DeclConsumer consumer);
535535

@@ -1933,6 +1933,42 @@ class ForEachPatternScope final : public ASTScopeImpl {
19331933
bool isLabeledStmtLookupTerminator() const override;
19341934
};
19351935

1936+
/// The parent scope for a 'case' statement, consisting of zero or more
1937+
/// CaseLabelItemScopes, followed by a CaseStmtBodyScope.
1938+
///
1939+
/// +------------------------------------------------------------------
1940+
/// | CaseStmtScope
1941+
/// +------------------------------------------------------------------
1942+
/// | +--------------------------+
1943+
/// | | CaseLabelItemScope: |
1944+
/// | +--------------------------+
1945+
/// | case .foo(let x, let y) where | condition(x, y), |
1946+
/// | ^------^--------------------^--^ |
1947+
/// | this guard expression sees first 'x'/'y' |
1948+
/// | +--------------------------+
1949+
/// |
1950+
/// | +--------------------------+
1951+
/// | | CaseLabelItemScope: |
1952+
/// | +--------------------------+
1953+
/// | .foo(let x, let y) where | condition(x, y), |
1954+
/// | ^------^--------------------^--^ |
1955+
/// | this guard expression sees second 'x'/'y' |
1956+
/// | +--------------------------+
1957+
/// |
1958+
/// | .bar(let x, let y)
1959+
/// | this case label item doesn't have a guard, so no
1960+
/// | scope is created.
1961+
/// |
1962+
/// | +----------------------------------------------------------------
1963+
/// | | CaseStmtBodyScope:
1964+
/// | +----------------------------------------------------------------
1965+
/// | | {
1966+
/// | | ... x, y <-- body sees "joined" 'x'/'y' created by parser
1967+
/// | | }
1968+
/// | +----------------------------------------------------------------
1969+
/// |
1970+
/// +------------------------------------------------------------------
1971+
19361972
class CaseStmtScope final : public AbstractStmtScope {
19371973
public:
19381974
CaseStmt *const stmt;
@@ -1945,17 +1981,66 @@ class CaseStmtScope final : public AbstractStmtScope {
19451981
private:
19461982
void expandAScopeThatDoesNotCreateANewInsertionPoint(ScopeCreator &);
19471983

1984+
public:
1985+
std::string getClassName() const override;
1986+
Stmt *getStmt() const override { return stmt; }
1987+
};
1988+
1989+
/// The scope used for the guard expression in a case statement. Any
1990+
/// variables bound by the case label item's pattern are visible in
1991+
/// this scope.
1992+
class CaseLabelItemScope final : public ASTScopeImpl {
1993+
public:
1994+
CaseLabelItem item;
1995+
CaseLabelItemScope(const CaseLabelItem &item) : item(item) {}
1996+
virtual ~CaseLabelItemScope() {}
1997+
1998+
protected:
1999+
ASTScopeImpl *expandSpecifically(ScopeCreator &scopeCreator) override;
2000+
2001+
private:
2002+
void expandAScopeThatDoesNotCreateANewInsertionPoint(ScopeCreator &);
2003+
19482004
public:
19492005
std::string getClassName() const override;
19502006
SourceRange
19512007
getSourceRangeOfThisASTNode(bool omitAssertions = false) const override;
1952-
Stmt *getStmt() const override { return stmt; }
19532008

19542009
protected:
19552010
bool lookupLocalsOrMembers(ArrayRef<const ASTScopeImpl *>,
19562011
ASTScopeImpl::DeclConsumer) const override;
19572012
};
19582013

2014+
/// The scope used for the body of a 'case' statement.
2015+
///
2016+
/// If the 'case' statement has multiple case label items, each label
2017+
/// item's pattern must bind the same variables; the parser creates
2018+
/// "fake" variables to represent the join of the variables bound by
2019+
/// each pattern.
2020+
///
2021+
/// These "fake" variables are visible in the 'case' statement body.
2022+
class CaseStmtBodyScope final : public ASTScopeImpl {
2023+
public:
2024+
CaseStmt *const stmt;
2025+
CaseStmtBodyScope(CaseStmt *e) : stmt(e) {}
2026+
virtual ~CaseStmtBodyScope() {}
2027+
2028+
protected:
2029+
ASTScopeImpl *expandSpecifically(ScopeCreator &scopeCreator) override;
2030+
2031+
private:
2032+
void expandAScopeThatDoesNotCreateANewInsertionPoint(ScopeCreator &);
2033+
2034+
public:
2035+
std::string getClassName() const override;
2036+
SourceRange
2037+
getSourceRangeOfThisASTNode(bool omitAssertions = false) const override;
2038+
protected:
2039+
bool lookupLocalsOrMembers(ArrayRef<const ASTScopeImpl *>,
2040+
ASTScopeImpl::DeclConsumer) const override;
2041+
bool isLabeledStmtLookupTerminator() const override;
2042+
};
2043+
19592044
class BraceStmtScope final : public AbstractStmtScope {
19602045

19612046
public:

lib/AST/ASTScope.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,8 @@ DEFINE_GET_CLASS_NAME(SwitchStmtScope)
253253
DEFINE_GET_CLASS_NAME(ForEachStmtScope)
254254
DEFINE_GET_CLASS_NAME(ForEachPatternScope)
255255
DEFINE_GET_CLASS_NAME(CaseStmtScope)
256+
DEFINE_GET_CLASS_NAME(CaseLabelItemScope)
257+
DEFINE_GET_CLASS_NAME(CaseStmtBodyScope)
256258
DEFINE_GET_CLASS_NAME(BraceStmtScope)
257259

258260
#undef DEFINE_GET_CLASS_NAME

lib/AST/ASTScopeCreation.cpp

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1169,6 +1169,8 @@ NO_NEW_INSERTION_POINT(EnumElementScope)
11691169

11701170
NO_NEW_INSERTION_POINT(CaptureListScope)
11711171
NO_NEW_INSERTION_POINT(CaseStmtScope)
1172+
NO_NEW_INSERTION_POINT(CaseLabelItemScope)
1173+
NO_NEW_INSERTION_POINT(CaseStmtBodyScope)
11721174
NO_NEW_INSERTION_POINT(ClosureBodyScope)
11731175
NO_NEW_INSERTION_POINT(DefaultArgumentInitializerScope)
11741176
NO_NEW_INSERTION_POINT(DoStmtScope)
@@ -1486,10 +1488,24 @@ void ForEachPatternScope::expandAScopeThatDoesNotCreateANewInsertionPoint(
14861488

14871489
void CaseStmtScope::expandAScopeThatDoesNotCreateANewInsertionPoint(
14881490
ScopeCreator &scopeCreator) {
1489-
for (auto &caseItem : stmt->getMutableCaseLabelItems())
1490-
scopeCreator.addToScopeTree(caseItem.getGuardExpr(), this);
1491+
for (auto &item : stmt->getCaseLabelItems()) {
1492+
if (item.getGuardExpr()) {
1493+
scopeCreator.constructExpandAndInsertUncheckable<CaseLabelItemScope>(
1494+
this, item);
1495+
}
1496+
}
1497+
1498+
scopeCreator.constructExpandAndInsertUncheckable<CaseStmtBodyScope>(
1499+
this, stmt);
1500+
}
14911501

1492-
// Add a child for the case body.
1502+
void CaseLabelItemScope::expandAScopeThatDoesNotCreateANewInsertionPoint(
1503+
ScopeCreator &scopeCreator) {
1504+
scopeCreator.addToScopeTree(item.getGuardExpr(), this);
1505+
}
1506+
1507+
void CaseStmtBodyScope::expandAScopeThatDoesNotCreateANewInsertionPoint(
1508+
ScopeCreator &scopeCreator) {
14931509
scopeCreator.addToScopeTree(stmt->getBody(), this);
14941510
}
14951511

lib/AST/ASTScopeLookup.cpp

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -422,12 +422,18 @@ bool ForEachPatternScope::lookupLocalsOrMembers(ArrayRef<const ASTScopeImpl *>,
422422
stmt->getPattern(), DeclVisibilityKind::LocalVariable, consumer);
423423
}
424424

425-
bool CaseStmtScope::lookupLocalsOrMembers(ArrayRef<const ASTScopeImpl *>,
426-
DeclConsumer consumer) const {
427-
for (auto &item : stmt->getMutableCaseLabelItems())
428-
if (lookupLocalBindingsInPattern(
429-
item.getPattern(), DeclVisibilityKind::LocalVariable, consumer))
430-
return true;
425+
bool CaseLabelItemScope::lookupLocalsOrMembers(ArrayRef<const ASTScopeImpl *>,
426+
DeclConsumer consumer) const {
427+
return lookupLocalBindingsInPattern(
428+
item.getPattern(), DeclVisibilityKind::LocalVariable, consumer);
429+
}
430+
431+
bool CaseStmtBodyScope::lookupLocalsOrMembers(ArrayRef<const ASTScopeImpl *>,
432+
DeclConsumer consumer) const {
433+
for (auto *var : stmt->getCaseBodyVariablesOrEmptyArray())
434+
if (consumer.consume({var}, DeclVisibilityKind::LocalVariable))
435+
return true;
436+
431437
return false;
432438
}
433439

@@ -554,7 +560,7 @@ bool ConditionalClausePatternUseScope::lookupLocalsOrMembers(
554560
pattern, DeclVisibilityKind::LocalVariable, consumer);
555561
}
556562

557-
bool ASTScopeImpl::lookupLocalBindingsInPattern(Pattern *p,
563+
bool ASTScopeImpl::lookupLocalBindingsInPattern(const Pattern *p,
558564
DeclVisibilityKind vis,
559565
DeclConsumer consumer) {
560566
if (!p)
@@ -893,6 +899,10 @@ bool ForEachPatternScope::isLabeledStmtLookupTerminator() const {
893899
return false;
894900
}
895901

902+
bool CaseStmtBodyScope::isLabeledStmtLookupTerminator() const {
903+
return false;
904+
}
905+
896906
llvm::SmallVector<LabeledStmt *, 4>
897907
ASTScopeImpl::lookupLabeledStmts(SourceFile *sourceFile, SourceLoc loc) {
898908
// Find the innermost scope from which to start our search.

lib/AST/ASTScopeSourceRange.cpp

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -408,20 +408,13 @@ SourceRange ForEachPatternScope::getSourceRangeOfThisASTNode(
408408
}
409409

410410
SourceRange
411-
CaseStmtScope::getSourceRangeOfThisASTNode(const bool omitAssertions) const {
412-
// The scope of the case statement begins at the first guard expression,
413-
// if there is one, and extends to the end of the body.
414-
// FIXME: Figure out what to do about multiple pattern bindings. We might
415-
// want a more restrictive rule in those cases.
416-
for (const auto &caseItem : stmt->getCaseLabelItems()) {
417-
if (auto guardExpr = caseItem.getGuardExpr())
418-
return SourceRange(guardExpr->getStartLoc(),
419-
stmt->getBody()->getEndLoc());
420-
}
411+
CaseLabelItemScope::getSourceRangeOfThisASTNode(const bool omitAssertions) const {
412+
return item.getGuardExpr()->getSourceRange();
413+
}
421414

422-
// Otherwise, it covers the body.
423-
return stmt->getBody()
424-
->getSourceRange(); // The scope of the case statement begins
415+
SourceRange
416+
CaseStmtBodyScope::getSourceRangeOfThisASTNode(const bool omitAssertions) const {
417+
return stmt->getBody()->getSourceRange();
425418
}
426419

427420
SourceRange

0 commit comments

Comments
 (0)