Skip to content

Commit 872176b

Browse files
authored
Merge pull request #84149 from hamishknight/case-and-pat
A few pattern cleanups + fixes
2 parents b740a45 + 10ed175 commit 872176b

29 files changed

+227
-362
lines changed

include/swift/AST/Stmt.h

Lines changed: 19 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,9 @@ class alignas(8) Stmt : public ASTAllocated<Stmt> {
8484
NumElements : 32
8585
);
8686

87-
SWIFT_INLINE_BITFIELD_FULL(CaseStmt, Stmt, 32,
87+
SWIFT_INLINE_BITFIELD_FULL(CaseStmt, Stmt, 16+32,
8888
: NumPadBits,
89+
NumCaseBodyVars : 16,
8990
NumPatterns : 32
9091
);
9192

@@ -1210,8 +1211,8 @@ enum CaseParentKind { Switch, DoCatch };
12101211
///
12111212
class CaseStmt final
12121213
: public Stmt,
1213-
private llvm::TrailingObjects<CaseStmt, FallthroughStmt *,
1214-
CaseLabelItem> {
1214+
private llvm::TrailingObjects<CaseStmt, FallthroughStmt *, CaseLabelItem,
1215+
VarDecl *> {
12151216
friend TrailingObjects;
12161217

12171218
Stmt *ParentStmt = nullptr;
@@ -1222,15 +1223,17 @@ class CaseStmt final
12221223

12231224
llvm::PointerIntPair<BraceStmt *, 1, bool> BodyAndHasFallthrough;
12241225

1225-
std::optional<MutableArrayRef<VarDecl *>> CaseBodyVariables;
1226-
12271226
CaseStmt(CaseParentKind ParentKind, SourceLoc ItemIntroducerLoc,
12281227
ArrayRef<CaseLabelItem> CaseLabelItems, SourceLoc UnknownAttrLoc,
12291228
SourceLoc ItemTerminatorLoc, BraceStmt *Body,
1230-
std::optional<MutableArrayRef<VarDecl *>> CaseBodyVariables,
1231-
std::optional<bool> Implicit,
1229+
ArrayRef<VarDecl *> CaseBodyVariables, std::optional<bool> Implicit,
12321230
NullablePtr<FallthroughStmt> fallthroughStmt);
12331231

1232+
MutableArrayRef<VarDecl *> getCaseBodyVariablesBuffer() {
1233+
return {getTrailingObjects<VarDecl *>(),
1234+
static_cast<size_t>(Bits.CaseStmt.NumCaseBodyVars)};
1235+
}
1236+
12341237
public:
12351238
/// Create a parsed 'case'/'default' for 'switch' statement.
12361239
static CaseStmt *
@@ -1244,13 +1247,17 @@ class CaseStmt final
12441247
ArrayRef<CaseLabelItem> CaseLabelItems,
12451248
BraceStmt *Body);
12461249

1250+
static CaseStmt *
1251+
createImplicit(ASTContext &ctx, CaseParentKind parentKind,
1252+
ArrayRef<CaseLabelItem> caseLabelItems, BraceStmt *body,
1253+
NullablePtr<FallthroughStmt> fallthroughStmt = nullptr);
1254+
12471255
static CaseStmt *
12481256
create(ASTContext &C, CaseParentKind ParentKind, SourceLoc ItemIntroducerLoc,
12491257
ArrayRef<CaseLabelItem> CaseLabelItems, SourceLoc UnknownAttrLoc,
12501258
SourceLoc ItemTerminatorLoc, BraceStmt *Body,
1251-
std::optional<MutableArrayRef<VarDecl *>> CaseBodyVariables,
1252-
std::optional<bool> Implicit = std::nullopt,
1253-
NullablePtr<FallthroughStmt> fallthroughStmt = nullptr);
1259+
ArrayRef<VarDecl *> CaseBodyVariables, std::optional<bool> Implicit,
1260+
NullablePtr<FallthroughStmt> fallthroughStmt);
12541261

12551262
CaseParentKind getParentKind() const { return ParentKind; }
12561263

@@ -1293,7 +1300,7 @@ class CaseStmt final
12931300
void setBody(BraceStmt *body) { BodyAndHasFallthrough.setPointer(body); }
12941301

12951302
/// True if the case block declares any patterns with local variable bindings.
1296-
bool hasBoundDecls() const { return CaseBodyVariables.has_value(); }
1303+
bool hasCaseBodyVariables() const { return !getCaseBodyVariables().empty(); }
12971304

12981305
/// Get the source location of the 'case', 'default', or 'catch' of the first
12991306
/// label.
@@ -1345,38 +1352,8 @@ class CaseStmt final
13451352
}
13461353

13471354
/// Return an ArrayRef containing the case body variables of this CaseStmt.
1348-
///
1349-
/// Asserts if case body variables was not explicitly initialized. In contexts
1350-
/// where one wants a non-asserting version, \see
1351-
/// getCaseBodyVariablesOrEmptyArray.
13521355
ArrayRef<VarDecl *> getCaseBodyVariables() const {
1353-
ArrayRef<VarDecl *> a = *CaseBodyVariables;
1354-
return a;
1355-
}
1356-
1357-
bool hasCaseBodyVariables() const { return CaseBodyVariables.has_value(); }
1358-
1359-
/// Return an MutableArrayRef containing the case body variables of this
1360-
/// CaseStmt.
1361-
///
1362-
/// Asserts if case body variables was not explicitly initialized. In contexts
1363-
/// where one wants a non-asserting version, \see
1364-
/// getCaseBodyVariablesOrEmptyArray.
1365-
MutableArrayRef<VarDecl *> getCaseBodyVariables() {
1366-
return *CaseBodyVariables;
1367-
}
1368-
1369-
ArrayRef<VarDecl *> getCaseBodyVariablesOrEmptyArray() const {
1370-
if (!CaseBodyVariables)
1371-
return ArrayRef<VarDecl *>();
1372-
ArrayRef<VarDecl *> a = *CaseBodyVariables;
1373-
return a;
1374-
}
1375-
1376-
MutableArrayRef<VarDecl *> getCaseBodyVariablesOrEmptyArray() {
1377-
if (!CaseBodyVariables)
1378-
return MutableArrayRef<VarDecl *>();
1379-
return *CaseBodyVariables;
1356+
return const_cast<CaseStmt *>(this)->getCaseBodyVariablesBuffer();
13801357
}
13811358

13821359
/// Find the next case statement within the same 'switch' or 'do-catch',

lib/AST/ASTScopeLookup.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -378,10 +378,10 @@ bool CaseLabelItemScope::lookupLocalsOrMembers(DeclConsumer consumer) const {
378378
}
379379

380380
bool CaseStmtBodyScope::lookupLocalsOrMembers(DeclConsumer consumer) const {
381-
for (auto *var : stmt->getCaseBodyVariablesOrEmptyArray())
381+
for (auto *var : stmt->getCaseBodyVariables()) {
382382
if (consumer.consume({var}))
383-
return true;
384-
383+
return true;
384+
}
385385
return false;
386386
}
387387

lib/AST/ASTVerifier.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2795,7 +2795,7 @@ class Verifier : public ASTWalker {
27952795
// guarantee that all case label items bind corresponding patterns and
27962796
// the case body var decls of a case stmt are created from the var decls
27972797
// of the first case label items.
2798-
if (!caseStmt->hasBoundDecls()) {
2798+
if (!caseStmt->hasCaseBodyVariables()) {
27992799
Out << "parent CaseStmt of VarDecl does not have any case body "
28002800
"decls?!\n";
28012801
abort();

lib/AST/Decl.cpp

Lines changed: 25 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -8159,50 +8159,6 @@ SourceRange AbstractStorageDecl::getTypeSourceRangeForDiagnostics() const {
81598159
return SourceRange();
81608160
}
81618161

8162-
static std::optional<std::pair<CaseStmt *, Pattern *>>
8163-
findParentPatternCaseStmtAndPattern(const VarDecl *inputVD) {
8164-
auto getMatchingPattern = [&](CaseStmt *cs) -> Pattern * {
8165-
// Check if inputVD is in our case body var decls if we have any. If we do,
8166-
// treat its pattern as our first case label item pattern.
8167-
for (auto *vd : cs->getCaseBodyVariablesOrEmptyArray()) {
8168-
if (vd == inputVD) {
8169-
return cs->getMutableCaseLabelItems().front().getPattern();
8170-
}
8171-
}
8172-
8173-
// Then check the rest of our case label items.
8174-
for (auto &item : cs->getMutableCaseLabelItems()) {
8175-
if (item.getPattern()->containsVarDecl(inputVD)) {
8176-
return item.getPattern();
8177-
}
8178-
}
8179-
8180-
// Otherwise return false if we do not find anything.
8181-
return nullptr;
8182-
};
8183-
8184-
// First find our canonical var decl. This is the VarDecl corresponding to the
8185-
// first case label item of the first case block in the fallthrough chain that
8186-
// our case block is within. Grab the case stmt associated with that var decl
8187-
// and start traveling down the fallthrough chain looking for the case
8188-
// statement that the input VD belongs to by using getMatchingPattern().
8189-
auto *canonicalVD = inputVD->getCanonicalVarDecl();
8190-
auto *caseStmt =
8191-
dyn_cast_or_null<CaseStmt>(canonicalVD->getParentPatternStmt());
8192-
if (!caseStmt)
8193-
return std::nullopt;
8194-
8195-
if (auto *p = getMatchingPattern(caseStmt))
8196-
return std::make_pair(caseStmt, p);
8197-
8198-
while ((caseStmt = caseStmt->getFallthroughDest().getPtrOrNull())) {
8199-
if (auto *p = getMatchingPattern(caseStmt))
8200-
return std::make_pair(caseStmt, p);
8201-
}
8202-
8203-
return std::nullopt;
8204-
}
8205-
82068162
VarDecl *VarDecl::getCanonicalVarDecl() const {
82078163
// Any var decl without a parent var decl is canonical. This means that before
82088164
// type checking, all var decls are canonical.
@@ -8227,16 +8183,7 @@ VarDecl *VarDecl::getCanonicalVarDecl() const {
82278183
}
82288184

82298185
Stmt *VarDecl::getRecursiveParentPatternStmt() const {
8230-
// If our parent is already a pattern stmt, just return that.
8231-
if (auto *stmt = getParentPatternStmt())
8232-
return stmt;
8233-
8234-
// Otherwise, see if we have a parent var decl. If we do not, then return
8235-
// nullptr. Otherwise, return the case stmt that we found.
8236-
auto result = findParentPatternCaseStmtAndPattern(this);
8237-
if (!result.has_value())
8238-
return nullptr;
8239-
return result->first;
8186+
return getCanonicalVarDecl()->getParentPatternStmt();
82408187
}
82418188

82428189
/// Return the Pattern involved in initializing this VarDecl. Recall that the
@@ -8256,17 +8203,34 @@ Pattern *VarDecl::getParentPattern() const {
82568203
}
82578204

82588205
// If this is a statement parent, dig the pattern out of it.
8259-
if (auto *stmt = getParentPatternStmt()) {
8206+
const auto *canonicalVD = getCanonicalVarDecl();
8207+
if (auto *stmt = canonicalVD->getParentPatternStmt()) {
82608208
if (auto *FES = dyn_cast<ForEachStmt>(stmt))
82618209
return FES->getPattern();
82628210

82638211
if (auto *cs = dyn_cast<CaseStmt>(stmt)) {
8264-
// In a case statement, search for the pattern that contains it. This is
8265-
// a bit silly, because you can't have something like "case x, y:" anyway.
8266-
for (auto items : cs->getCaseLabelItems()) {
8267-
if (items.getPattern()->containsVarDecl(this))
8268-
return items.getPattern();
8212+
// In a case statement, search for the pattern that contains it.
8213+
auto findPattern = [](CaseStmt *cs, const VarDecl *VD) -> Pattern * {
8214+
for (auto items : cs->getCaseLabelItems()) {
8215+
if (items.getPattern()->containsVarDecl(VD))
8216+
return items.getPattern();
8217+
}
8218+
return nullptr;
8219+
};
8220+
if (auto *P = findPattern(cs, this))
8221+
return P;
8222+
8223+
// If it's not in the CaseStmt, check its fallthrough destination.
8224+
if (auto fallthrough = cs->getFallthroughDest()) {
8225+
if (auto *P = findPattern(fallthrough.get(), this))
8226+
return P;
82698227
}
8228+
8229+
// Finally, check the canonical variable, this is necessary to correctly
8230+
// handle case body vars, we just want to take the first pattern that
8231+
// declares it in that case.
8232+
if (auto *P = findPattern(cs, canonicalVD))
8233+
return P;
82708234
}
82718235

82728236
if (auto *LCS = dyn_cast<LabeledConditionalStmt>(stmt)) {
@@ -8277,14 +8241,6 @@ Pattern *VarDecl::getParentPattern() const {
82778241
}
82788242
}
82798243

8280-
// Otherwise, check if we have to walk our case stmt's var decl list to find
8281-
// the pattern.
8282-
if (auto caseStmtPatternPair = findParentPatternCaseStmtAndPattern(this)) {
8283-
return caseStmtPatternPair->second;
8284-
}
8285-
8286-
// Otherwise, this is a case we do not know or understand. Return nullptr to
8287-
// signal we do not have any information.
82888244
return nullptr;
82898245
}
82908246

@@ -8345,7 +8301,7 @@ bool VarDecl::isCaseBodyVariable() const {
83458301
auto *caseStmt = dyn_cast_or_null<CaseStmt>(getRecursiveParentPatternStmt());
83468302
if (!caseStmt)
83478303
return false;
8348-
return llvm::any_of(caseStmt->getCaseBodyVariablesOrEmptyArray(),
8304+
return llvm::any_of(caseStmt->getCaseBodyVariables(),
83498305
[&](VarDecl *vd) { return vd == this; });
83508306
}
83518307

lib/AST/Pattern.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -208,10 +208,14 @@ namespace {
208208
return Action::Continue(P);
209209
}
210210

211-
// Only walk into an expression insofar as it doesn't open a new scope -
212-
// that is, don't walk into a closure body.
213211
PreWalkResult<Expr *> walkToExprPre(Expr *E) override {
214-
if (isa<ClosureExpr>(E)) {
212+
// Only walk into an expression insofar as it doesn't open a new scope -
213+
// that is, don't walk into a closure body, TapExpr, or
214+
// SingleValueStmtExpr. Also don't walk into key paths since any nested
215+
// VarDecls are invalid there, and after being diagnosed by key path
216+
// resolution the ASTWalker won't visit them.
217+
if (isa<ClosureExpr>(E) || isa<TapExpr>(E) ||
218+
isa<SingleValueStmtExpr>(E) || isa<KeyPathExpr>(E)) {
215219
return Action::SkipNode(E);
216220
}
217221
return Action::Continue(E);

0 commit comments

Comments
 (0)