@@ -8159,50 +8159,6 @@ SourceRange AbstractStorageDecl::getTypeSourceRangeForDiagnostics() const {
8159
8159
return SourceRange ();
8160
8160
}
8161
8161
8162
- static std::optional<std::pair<CaseStmt *, Pattern *>>
8163
- findParentPatternCaseStmtAndPattern (const VarDecl *inputVD) {
8164
- auto getMatchingPattern = [&](CaseStmt *cs) -> Pattern * {
8165
- // Check if inputVD is in our case body var decls if we have any. If we do,
8166
- // treat its pattern as our first case label item pattern.
8167
- for (auto *vd : cs->getCaseBodyVariablesOrEmptyArray ()) {
8168
- if (vd == inputVD) {
8169
- return cs->getMutableCaseLabelItems ().front ().getPattern ();
8170
- }
8171
- }
8172
-
8173
- // Then check the rest of our case label items.
8174
- for (auto &item : cs->getMutableCaseLabelItems ()) {
8175
- if (item.getPattern ()->containsVarDecl (inputVD)) {
8176
- return item.getPattern ();
8177
- }
8178
- }
8179
-
8180
- // Otherwise return false if we do not find anything.
8181
- return nullptr ;
8182
- };
8183
-
8184
- // First find our canonical var decl. This is the VarDecl corresponding to the
8185
- // first case label item of the first case block in the fallthrough chain that
8186
- // our case block is within. Grab the case stmt associated with that var decl
8187
- // and start traveling down the fallthrough chain looking for the case
8188
- // statement that the input VD belongs to by using getMatchingPattern().
8189
- auto *canonicalVD = inputVD->getCanonicalVarDecl ();
8190
- auto *caseStmt =
8191
- dyn_cast_or_null<CaseStmt>(canonicalVD->getParentPatternStmt ());
8192
- if (!caseStmt)
8193
- return std::nullopt ;
8194
-
8195
- if (auto *p = getMatchingPattern (caseStmt))
8196
- return std::make_pair (caseStmt, p);
8197
-
8198
- while ((caseStmt = caseStmt->getFallthroughDest ().getPtrOrNull ())) {
8199
- if (auto *p = getMatchingPattern (caseStmt))
8200
- return std::make_pair (caseStmt, p);
8201
- }
8202
-
8203
- return std::nullopt ;
8204
- }
8205
-
8206
8162
VarDecl *VarDecl::getCanonicalVarDecl () const {
8207
8163
// Any var decl without a parent var decl is canonical. This means that before
8208
8164
// type checking, all var decls are canonical.
@@ -8227,16 +8183,7 @@ VarDecl *VarDecl::getCanonicalVarDecl() const {
8227
8183
}
8228
8184
8229
8185
Stmt *VarDecl::getRecursiveParentPatternStmt () const {
8230
- // If our parent is already a pattern stmt, just return that.
8231
- if (auto *stmt = getParentPatternStmt ())
8232
- return stmt;
8233
-
8234
- // Otherwise, see if we have a parent var decl. If we do not, then return
8235
- // nullptr. Otherwise, return the case stmt that we found.
8236
- auto result = findParentPatternCaseStmtAndPattern (this );
8237
- if (!result.has_value ())
8238
- return nullptr ;
8239
- return result->first ;
8186
+ return getCanonicalVarDecl ()->getParentPatternStmt ();
8240
8187
}
8241
8188
8242
8189
// / Return the Pattern involved in initializing this VarDecl. Recall that the
@@ -8256,17 +8203,34 @@ Pattern *VarDecl::getParentPattern() const {
8256
8203
}
8257
8204
8258
8205
// If this is a statement parent, dig the pattern out of it.
8259
- if (auto *stmt = getParentPatternStmt ()) {
8206
+ const auto *canonicalVD = getCanonicalVarDecl ();
8207
+ if (auto *stmt = canonicalVD->getParentPatternStmt ()) {
8260
8208
if (auto *FES = dyn_cast<ForEachStmt>(stmt))
8261
8209
return FES->getPattern ();
8262
8210
8263
8211
if (auto *cs = dyn_cast<CaseStmt>(stmt)) {
8264
- // In a case statement, search for the pattern that contains it. This is
8265
- // a bit silly, because you can't have something like "case x, y:" anyway.
8266
- for (auto items : cs->getCaseLabelItems ()) {
8267
- if (items.getPattern ()->containsVarDecl (this ))
8268
- return items.getPattern ();
8212
+ // In a case statement, search for the pattern that contains it.
8213
+ auto findPattern = [](CaseStmt *cs, const VarDecl *VD) -> Pattern * {
8214
+ for (auto items : cs->getCaseLabelItems ()) {
8215
+ if (items.getPattern ()->containsVarDecl (VD))
8216
+ return items.getPattern ();
8217
+ }
8218
+ return nullptr ;
8219
+ };
8220
+ if (auto *P = findPattern (cs, this ))
8221
+ return P;
8222
+
8223
+ // If it's not in the CaseStmt, check its fallthrough destination.
8224
+ if (auto fallthrough = cs->getFallthroughDest ()) {
8225
+ if (auto *P = findPattern (fallthrough.get (), this ))
8226
+ return P;
8269
8227
}
8228
+
8229
+ // Finally, check the canonical variable, this is necessary to correctly
8230
+ // handle case body vars, we just want to take the first pattern that
8231
+ // declares it in that case.
8232
+ if (auto *P = findPattern (cs, canonicalVD))
8233
+ return P;
8270
8234
}
8271
8235
8272
8236
if (auto *LCS = dyn_cast<LabeledConditionalStmt>(stmt)) {
@@ -8277,14 +8241,6 @@ Pattern *VarDecl::getParentPattern() const {
8277
8241
}
8278
8242
}
8279
8243
8280
- // Otherwise, check if we have to walk our case stmt's var decl list to find
8281
- // the pattern.
8282
- if (auto caseStmtPatternPair = findParentPatternCaseStmtAndPattern (this )) {
8283
- return caseStmtPatternPair->second ;
8284
- }
8285
-
8286
- // Otherwise, this is a case we do not know or understand. Return nullptr to
8287
- // signal we do not have any information.
8288
8244
return nullptr ;
8289
8245
}
8290
8246
@@ -8345,7 +8301,7 @@ bool VarDecl::isCaseBodyVariable() const {
8345
8301
auto *caseStmt = dyn_cast_or_null<CaseStmt>(getRecursiveParentPatternStmt ());
8346
8302
if (!caseStmt)
8347
8303
return false ;
8348
- return llvm::any_of (caseStmt->getCaseBodyVariablesOrEmptyArray (),
8304
+ return llvm::any_of (caseStmt->getCaseBodyVariables (),
8349
8305
[&](VarDecl *vd) { return vd == this ; });
8350
8306
}
8351
8307
0 commit comments