Skip to content

Commit 79fe1b3

Browse files
committed
[AST] Remove findParentPatternCaseStmtAndPattern
Add the extra logic to `VarDecl::getParentPattern` necessary to handle fallthrough and case body variables instead. This also changes the behavior for case body vars - previously we would return the first pattern in the CaseStmt, but that's not necessarily correct. Instead, return the first pattern that actually binds the variable.
1 parent 413824c commit 79fe1b3

File tree

1 file changed

+23
-58
lines changed

1 file changed

+23
-58
lines changed

lib/AST/Decl.cpp

Lines changed: 23 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -8159,50 +8159,6 @@ SourceRange AbstractStorageDecl::getTypeSourceRangeForDiagnostics() const {
81598159
return SourceRange();
81608160
}
81618161

8162-
static std::optional<std::pair<CaseStmt *, Pattern *>>
8163-
findParentPatternCaseStmtAndPattern(const VarDecl *inputVD) {
8164-
auto getMatchingPattern = [&](CaseStmt *cs) -> Pattern * {
8165-
// Check if inputVD is in our case body var decls if we have any. If we do,
8166-
// treat its pattern as our first case label item pattern.
8167-
for (auto *vd : cs->getCaseBodyVariables()) {
8168-
if (vd == inputVD) {
8169-
return cs->getMutableCaseLabelItems().front().getPattern();
8170-
}
8171-
}
8172-
8173-
// Then check the rest of our case label items.
8174-
for (auto &item : cs->getMutableCaseLabelItems()) {
8175-
if (item.getPattern()->containsVarDecl(inputVD)) {
8176-
return item.getPattern();
8177-
}
8178-
}
8179-
8180-
// Otherwise return false if we do not find anything.
8181-
return nullptr;
8182-
};
8183-
8184-
// First find our canonical var decl. This is the VarDecl corresponding to the
8185-
// first case label item of the first case block in the fallthrough chain that
8186-
// our case block is within. Grab the case stmt associated with that var decl
8187-
// and start traveling down the fallthrough chain looking for the case
8188-
// statement that the input VD belongs to by using getMatchingPattern().
8189-
auto *canonicalVD = inputVD->getCanonicalVarDecl();
8190-
auto *caseStmt =
8191-
dyn_cast_or_null<CaseStmt>(canonicalVD->getParentPatternStmt());
8192-
if (!caseStmt)
8193-
return std::nullopt;
8194-
8195-
if (auto *p = getMatchingPattern(caseStmt))
8196-
return std::make_pair(caseStmt, p);
8197-
8198-
while ((caseStmt = caseStmt->getFallthroughDest().getPtrOrNull())) {
8199-
if (auto *p = getMatchingPattern(caseStmt))
8200-
return std::make_pair(caseStmt, p);
8201-
}
8202-
8203-
return std::nullopt;
8204-
}
8205-
82068162
VarDecl *VarDecl::getCanonicalVarDecl() const {
82078163
// Any var decl without a parent var decl is canonical. This means that before
82088164
// type checking, all var decls are canonical.
@@ -8247,17 +8203,34 @@ Pattern *VarDecl::getParentPattern() const {
82478203
}
82488204

82498205
// If this is a statement parent, dig the pattern out of it.
8250-
if (auto *stmt = getParentPatternStmt()) {
8206+
const auto *canonicalVD = getCanonicalVarDecl();
8207+
if (auto *stmt = canonicalVD->getParentPatternStmt()) {
82518208
if (auto *FES = dyn_cast<ForEachStmt>(stmt))
82528209
return FES->getPattern();
82538210

82548211
if (auto *cs = dyn_cast<CaseStmt>(stmt)) {
8255-
// In a case statement, search for the pattern that contains it. This is
8256-
// a bit silly, because you can't have something like "case x, y:" anyway.
8257-
for (auto items : cs->getCaseLabelItems()) {
8258-
if (items.getPattern()->containsVarDecl(this))
8259-
return items.getPattern();
8212+
// In a case statement, search for the pattern that contains it.
8213+
auto findPattern = [](CaseStmt *cs, const VarDecl *VD) -> Pattern * {
8214+
for (auto items : cs->getCaseLabelItems()) {
8215+
if (items.getPattern()->containsVarDecl(VD))
8216+
return items.getPattern();
8217+
}
8218+
return nullptr;
8219+
};
8220+
if (auto *P = findPattern(cs, this))
8221+
return P;
8222+
8223+
// If it's not in the CaseStmt, check its fallthrough destination.
8224+
if (auto fallthrough = cs->getFallthroughDest()) {
8225+
if (auto *P = findPattern(fallthrough.get(), this))
8226+
return P;
82608227
}
8228+
8229+
// Finally, check the canonical variable, this is necessary to correctly
8230+
// handle case body vars, we just want to take the first pattern that
8231+
// declares it in that case.
8232+
if (auto *P = findPattern(cs, canonicalVD))
8233+
return P;
82618234
}
82628235

82638236
if (auto *LCS = dyn_cast<LabeledConditionalStmt>(stmt)) {
@@ -8268,14 +8241,6 @@ Pattern *VarDecl::getParentPattern() const {
82688241
}
82698242
}
82708243

8271-
// Otherwise, check if we have to walk our case stmt's var decl list to find
8272-
// the pattern.
8273-
if (auto caseStmtPatternPair = findParentPatternCaseStmtAndPattern(this)) {
8274-
return caseStmtPatternPair->second;
8275-
}
8276-
8277-
// Otherwise, this is a case we do not know or understand. Return nullptr to
8278-
// signal we do not have any information.
82798244
return nullptr;
82808245
}
82818246

0 commit comments

Comments
 (0)