@@ -8159,50 +8159,6 @@ SourceRange AbstractStorageDecl::getTypeSourceRangeForDiagnostics() const {
8159
8159
return SourceRange ();
8160
8160
}
8161
8161
8162
- static std::optional<std::pair<CaseStmt *, Pattern *>>
8163
- findParentPatternCaseStmtAndPattern (const VarDecl *inputVD) {
8164
- auto getMatchingPattern = [&](CaseStmt *cs) -> Pattern * {
8165
- // Check if inputVD is in our case body var decls if we have any. If we do,
8166
- // treat its pattern as our first case label item pattern.
8167
- for (auto *vd : cs->getCaseBodyVariables ()) {
8168
- if (vd == inputVD) {
8169
- return cs->getMutableCaseLabelItems ().front ().getPattern ();
8170
- }
8171
- }
8172
-
8173
- // Then check the rest of our case label items.
8174
- for (auto &item : cs->getMutableCaseLabelItems ()) {
8175
- if (item.getPattern ()->containsVarDecl (inputVD)) {
8176
- return item.getPattern ();
8177
- }
8178
- }
8179
-
8180
- // Otherwise return false if we do not find anything.
8181
- return nullptr ;
8182
- };
8183
-
8184
- // First find our canonical var decl. This is the VarDecl corresponding to the
8185
- // first case label item of the first case block in the fallthrough chain that
8186
- // our case block is within. Grab the case stmt associated with that var decl
8187
- // and start traveling down the fallthrough chain looking for the case
8188
- // statement that the input VD belongs to by using getMatchingPattern().
8189
- auto *canonicalVD = inputVD->getCanonicalVarDecl ();
8190
- auto *caseStmt =
8191
- dyn_cast_or_null<CaseStmt>(canonicalVD->getParentPatternStmt ());
8192
- if (!caseStmt)
8193
- return std::nullopt ;
8194
-
8195
- if (auto *p = getMatchingPattern (caseStmt))
8196
- return std::make_pair (caseStmt, p);
8197
-
8198
- while ((caseStmt = caseStmt->getFallthroughDest ().getPtrOrNull ())) {
8199
- if (auto *p = getMatchingPattern (caseStmt))
8200
- return std::make_pair (caseStmt, p);
8201
- }
8202
-
8203
- return std::nullopt ;
8204
- }
8205
-
8206
8162
VarDecl *VarDecl::getCanonicalVarDecl () const {
8207
8163
// Any var decl without a parent var decl is canonical. This means that before
8208
8164
// type checking, all var decls are canonical.
@@ -8247,17 +8203,34 @@ Pattern *VarDecl::getParentPattern() const {
8247
8203
}
8248
8204
8249
8205
// If this is a statement parent, dig the pattern out of it.
8250
- if (auto *stmt = getParentPatternStmt ()) {
8206
+ const auto *canonicalVD = getCanonicalVarDecl ();
8207
+ if (auto *stmt = canonicalVD->getParentPatternStmt ()) {
8251
8208
if (auto *FES = dyn_cast<ForEachStmt>(stmt))
8252
8209
return FES->getPattern ();
8253
8210
8254
8211
if (auto *cs = dyn_cast<CaseStmt>(stmt)) {
8255
- // In a case statement, search for the pattern that contains it. This is
8256
- // a bit silly, because you can't have something like "case x, y:" anyway.
8257
- for (auto items : cs->getCaseLabelItems ()) {
8258
- if (items.getPattern ()->containsVarDecl (this ))
8259
- return items.getPattern ();
8212
+ // In a case statement, search for the pattern that contains it.
8213
+ auto findPattern = [](CaseStmt *cs, const VarDecl *VD) -> Pattern * {
8214
+ for (auto items : cs->getCaseLabelItems ()) {
8215
+ if (items.getPattern ()->containsVarDecl (VD))
8216
+ return items.getPattern ();
8217
+ }
8218
+ return nullptr ;
8219
+ };
8220
+ if (auto *P = findPattern (cs, this ))
8221
+ return P;
8222
+
8223
+ // If it's not in the CaseStmt, check its fallthrough destination.
8224
+ if (auto fallthrough = cs->getFallthroughDest ()) {
8225
+ if (auto *P = findPattern (fallthrough.get (), this ))
8226
+ return P;
8260
8227
}
8228
+
8229
+ // Finally, check the canonical variable, this is necessary to correctly
8230
+ // handle case body vars, we just want to take the first pattern that
8231
+ // declares it in that case.
8232
+ if (auto *P = findPattern (cs, canonicalVD))
8233
+ return P;
8261
8234
}
8262
8235
8263
8236
if (auto *LCS = dyn_cast<LabeledConditionalStmt>(stmt)) {
@@ -8268,14 +8241,6 @@ Pattern *VarDecl::getParentPattern() const {
8268
8241
}
8269
8242
}
8270
8243
8271
- // Otherwise, check if we have to walk our case stmt's var decl list to find
8272
- // the pattern.
8273
- if (auto caseStmtPatternPair = findParentPatternCaseStmtAndPattern (this )) {
8274
- return caseStmtPatternPair->second ;
8275
- }
8276
-
8277
- // Otherwise, this is a case we do not know or understand. Return nullptr to
8278
- // signal we do not have any information.
8279
8244
return nullptr ;
8280
8245
}
8281
8246
0 commit comments