Skip to content

Commit 86888ce

Browse files
authored
Merge pull request swiftlang#23499 from gottesmm/pr-7c99cd81f730147d0e6b637f864da32d1511a450
[parse/sema] Give all case bodies their own var decls without using t…
2 parents 927ef3c + b50d878 commit 86888ce

File tree

10 files changed

+228
-75
lines changed

10 files changed

+228
-75
lines changed

include/swift/AST/Decl.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4715,7 +4715,7 @@ class VarDecl : public AbstractStorageDecl {
47154715
/// Set \p v to be the pattern produced VarDecl that is the parent of this
47164716
/// var decl.
47174717
void setParentVarDecl(VarDecl *v) {
4718-
assert(v);
4718+
assert(v && v != this);
47194719
Parent = v;
47204720
}
47214721

include/swift/AST/Stmt.h

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,13 @@
1717
#ifndef SWIFT_AST_STMT_H
1818
#define SWIFT_AST_STMT_H
1919

20+
#include "swift/AST/ASTNode.h"
2021
#include "swift/AST/Availability.h"
2122
#include "swift/AST/AvailabilitySpec.h"
22-
#include "swift/AST/ASTNode.h"
2323
#include "swift/AST/IfConfigClause.h"
2424
#include "swift/AST/TypeAlignments.h"
2525
#include "swift/Basic/NullablePtr.h"
26+
#include "llvm/ADT/TinyPtrVector.h"
2627
#include "llvm/Support/TrailingObjects.h"
2728

2829
namespace swift {
@@ -984,24 +985,22 @@ class CaseStmt final
984985
SourceLoc CaseLoc;
985986
SourceLoc ColonLoc;
986987

987-
llvm::PointerIntPair<Stmt *, 1, bool> BodyAndHasBoundDecls;
988+
llvm::PointerIntPair<Stmt *, 1, bool> BodyAndHasFallthrough;
988989

989-
/// Set to true if we have a fallthrough.
990-
///
991-
/// TODO: Once we have CaseBodyVarDecls, use the bit in BodyAndHasBoundDecls
992-
/// for this instead. This is separate now for staging reasons.
993-
bool hasFallthrough;
990+
Optional<MutableArrayRef<VarDecl *>> CaseBodyVariables;
994991

995992
CaseStmt(SourceLoc CaseLoc, ArrayRef<CaseLabelItem> CaseLabelItems,
996-
bool HasBoundDecls, SourceLoc UnknownAttrLoc, SourceLoc ColonLoc,
997-
Stmt *Body, Optional<bool> Implicit,
993+
SourceLoc UnknownAttrLoc, SourceLoc ColonLoc, Stmt *Body,
994+
Optional<MutableArrayRef<VarDecl *>> CaseBodyVariables,
995+
Optional<bool> Implicit,
998996
NullablePtr<FallthroughStmt> fallthroughStmt);
999997

1000998
public:
1001999
static CaseStmt *
10021000
create(ASTContext &C, SourceLoc CaseLoc,
1003-
ArrayRef<CaseLabelItem> CaseLabelItems, bool HasBoundDecls,
1004-
SourceLoc UnknownAttrLoc, SourceLoc ColonLoc, Stmt *Body,
1001+
ArrayRef<CaseLabelItem> CaseLabelItems, SourceLoc UnknownAttrLoc,
1002+
SourceLoc ColonLoc, Stmt *Body,
1003+
Optional<MutableArrayRef<VarDecl *>> CaseBodyVariables,
10051004
Optional<bool> Implicit = None,
10061005
NullablePtr<FallthroughStmt> fallthroughStmt = nullptr);
10071006

@@ -1020,18 +1019,18 @@ class CaseStmt final
10201019
}
10211020

10221021
NullablePtr<CaseStmt> getFallthroughDest() {
1023-
if (!hasFallthrough)
1022+
if (!hasFallthroughDest())
10241023
return nullptr;
10251024
return (*getTrailingObjects<FallthroughStmt *>())->getFallthroughDest();
10261025
}
10271026

1028-
bool hasFallthroughDest() const { return hasFallthrough; }
1027+
bool hasFallthroughDest() const { return BodyAndHasFallthrough.getInt(); }
10291028

1030-
Stmt *getBody() const { return BodyAndHasBoundDecls.getPointer(); }
1031-
void setBody(Stmt *body) { BodyAndHasBoundDecls.setPointer(body); }
1029+
Stmt *getBody() const { return BodyAndHasFallthrough.getPointer(); }
1030+
void setBody(Stmt *body) { BodyAndHasFallthrough.setPointer(body); }
10321031

10331032
/// True if the case block declares any patterns with local variable bindings.
1034-
bool hasBoundDecls() const { return BodyAndHasBoundDecls.getInt(); }
1033+
bool hasBoundDecls() const { return CaseBodyVariables.hasValue(); }
10351034

10361035
/// Get the source location of the 'case' or 'default' of the first label.
10371036
SourceLoc getLoc() const { return CaseLoc; }
@@ -1056,14 +1055,38 @@ class CaseStmt final
10561055
return UnknownAttrLoc.isValid();
10571056
}
10581057

1058+
Optional<ArrayRef<VarDecl *>> getCaseBodyVariables() const {
1059+
if (!CaseBodyVariables)
1060+
return None;
1061+
ArrayRef<VarDecl *> a = *CaseBodyVariables;
1062+
return a;
1063+
}
1064+
1065+
Optional<MutableArrayRef<VarDecl *>> getCaseBodyVariables() {
1066+
return CaseBodyVariables;
1067+
}
1068+
1069+
ArrayRef<VarDecl *> getCaseBodyVariablesOrEmptyArray() const {
1070+
if (!CaseBodyVariables)
1071+
return ArrayRef<VarDecl *>();
1072+
ArrayRef<VarDecl *> a = *CaseBodyVariables;
1073+
return a;
1074+
}
1075+
1076+
MutableArrayRef<VarDecl *> getCaseBodyVariablesOrEmptyArray() {
1077+
if (!CaseBodyVariables)
1078+
return MutableArrayRef<VarDecl *>();
1079+
return *CaseBodyVariables;
1080+
}
1081+
10591082
static bool classof(const Stmt *S) { return S->getKind() == StmtKind::Case; }
10601083

10611084
size_t numTrailingObjects(OverloadToken<CaseLabelItem>) const {
10621085
return getNumCaseLabelItems();
10631086
}
10641087

10651088
size_t numTrailingObjects(OverloadToken<FallthroughStmt *>) const {
1066-
return hasFallthrough ? 1 : 0;
1089+
return hasFallthroughDest() ? 1 : 0;
10671090
}
10681091
};
10691092

lib/AST/ASTDumper.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1619,6 +1619,24 @@ class PrintStmt : public StmtVisitor<PrintStmt> {
16191619
printCommon(S, "case_stmt");
16201620
if (S->hasUnknownAttr())
16211621
OS << " @unknown";
1622+
1623+
if (auto caseBodyVars = S->getCaseBodyVariables()) {
1624+
OS << '\n';
1625+
OS.indent(Indent + 2);
1626+
PrintWithColorRAII(OS, ParenthesisColor) << '(';
1627+
PrintWithColorRAII(OS, StmtColor) << "case_body_variables";
1628+
OS << '\n';
1629+
for (auto *vd : *caseBodyVars) {
1630+
OS.indent(2);
1631+
// TODO: Printing a var decl does an Indent ... dump(vd) ... '\n'. We
1632+
// should see if we can factor this dumping so that the caller of
1633+
// printRec(VarDecl) has more control over the printing.
1634+
printRec(vd);
1635+
}
1636+
OS.indent(Indent + 2);
1637+
PrintWithColorRAII(OS, ParenthesisColor) << ')';
1638+
}
1639+
16221640
for (const auto &LabelItem : S->getCaseLabelItems()) {
16231641
OS << '\n';
16241642
OS.indent(Indent + 2);

lib/AST/Decl.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4921,11 +4921,22 @@ static bool isVarInPattern(const VarDecl *vd, Pattern *p) {
49214921
static Optional<std::pair<CaseStmt *, Pattern *>>
49224922
findParentPatternCaseStmtAndPattern(const VarDecl *inputVD) {
49234923
auto getMatchingPattern = [&](CaseStmt *cs) -> Pattern * {
4924+
// Check if inputVD is in our case body var decls if we have any. If we do,
4925+
// treat its pattern as our first case label item pattern.
4926+
for (auto *vd : cs->getCaseBodyVariablesOrEmptyArray()) {
4927+
if (vd == inputVD) {
4928+
return cs->getMutableCaseLabelItems().front().getPattern();
4929+
}
4930+
}
4931+
4932+
// Then check the rest of our case label items.
49244933
for (auto &item : cs->getMutableCaseLabelItems()) {
49254934
if (isVarInPattern(inputVD, item.getPattern())) {
49264935
return item.getPattern();
49274936
}
49284937
}
4938+
4939+
// Otherwise return false if we do not find anything.
49294940
return nullptr;
49304941
};
49314942

@@ -4959,9 +4970,16 @@ VarDecl *VarDecl::getCanonicalVarDecl() const {
49594970
if (!vd)
49604971
return cur;
49614972

4973+
#ifndef NDEBUG
4974+
// Make sure that we don't get into an infinite loop.
4975+
SmallPtrSet<VarDecl *, 8> visitedDecls;
4976+
visitedDecls.insert(vd);
4977+
visitedDecls.insert(cur);
4978+
#endif
49624979
while (vd) {
49634980
cur = vd;
49644981
vd = vd->getParentVarDecl();
4982+
assert((!vd || visitedDecls.insert(vd).second) && "Infinite loop ?!");
49654983
}
49664984

49674985
return cur;

lib/AST/Stmt.cpp

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -386,43 +386,50 @@ SourceLoc CaseLabelItem::getEndLoc() const {
386386
}
387387

388388
CaseStmt::CaseStmt(SourceLoc caseLoc, ArrayRef<CaseLabelItem> caseLabelItems,
389-
bool hasBoundDecls, SourceLoc unknownAttrLoc,
390-
SourceLoc colonLoc, Stmt *body, Optional<bool> implicit,
389+
SourceLoc unknownAttrLoc, SourceLoc colonLoc, Stmt *body,
390+
Optional<MutableArrayRef<VarDecl *>> caseBodyVariables,
391+
Optional<bool> implicit,
391392
NullablePtr<FallthroughStmt> fallthroughStmt)
392393
: Stmt(StmtKind::Case, getDefaultImplicitFlag(implicit, caseLoc)),
393394
UnknownAttrLoc(unknownAttrLoc), CaseLoc(caseLoc), ColonLoc(colonLoc),
394-
BodyAndHasBoundDecls(body, hasBoundDecls),
395-
hasFallthrough(fallthroughStmt.isNonNull()) {
395+
BodyAndHasFallthrough(body, fallthroughStmt.isNonNull()),
396+
CaseBodyVariables(caseBodyVariables) {
396397
Bits.CaseStmt.NumPatterns = caseLabelItems.size();
397398
assert(Bits.CaseStmt.NumPatterns > 0 &&
398399
"case block must have at least one pattern");
399400

400-
if (hasFallthrough) {
401+
if (hasFallthroughDest()) {
401402
*getTrailingObjects<FallthroughStmt *>() = fallthroughStmt.get();
402403
}
403404

404405
MutableArrayRef<CaseLabelItem> items{getTrailingObjects<CaseLabelItem>(),
405406
Bits.CaseStmt.NumPatterns};
406407

408+
// At the beginning mark all of our var decls as being owned by this
409+
// statement. In the typechecker we wireup the case stmt var decl list since
410+
// we know everything is lined up/typechecked then.
407411
for (unsigned i : range(Bits.CaseStmt.NumPatterns)) {
408412
new (&items[i]) CaseLabelItem(caseLabelItems[i]);
409413
items[i].getPattern()->markOwnedByStatement(this);
410414
}
415+
for (auto *vd : caseBodyVariables.getValueOr(MutableArrayRef<VarDecl *>())) {
416+
vd->setParentPatternStmt(this);
417+
}
411418
}
412419

413420
CaseStmt *CaseStmt::create(ASTContext &ctx, SourceLoc caseLoc,
414421
ArrayRef<CaseLabelItem> caseLabelItems,
415-
bool hasBoundDecls, SourceLoc unknownAttrLoc,
416-
SourceLoc colonLoc, Stmt *body,
422+
SourceLoc unknownAttrLoc, SourceLoc colonLoc,
423+
Stmt *body,
424+
Optional<MutableArrayRef<VarDecl *>> caseVarDecls,
417425
Optional<bool> implicit,
418426
NullablePtr<FallthroughStmt> fallthroughStmt) {
419427
void *mem =
420428
ctx.Allocate(totalSizeToAlloc<FallthroughStmt *, CaseLabelItem>(
421429
fallthroughStmt.isNonNull(), caseLabelItems.size()),
422430
alignof(CaseStmt));
423-
return ::new (mem)
424-
CaseStmt(caseLoc, caseLabelItems, hasBoundDecls, unknownAttrLoc, colonLoc,
425-
body, implicit, fallthroughStmt);
431+
return ::new (mem) CaseStmt(caseLoc, caseLabelItems, unknownAttrLoc, colonLoc,
432+
body, caseVarDecls, implicit, fallthroughStmt);
426433
}
427434

428435
SwitchStmt *SwitchStmt::create(LabeledStmtInfo LabelInfo, SourceLoc SwitchLoc,

lib/Parse/ParseStmt.cpp

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2295,10 +2295,11 @@ Parser::parseStmtCases(SmallVectorImpl<ASTNode> &cases, bool IsActive) {
22952295
return Status;
22962296
}
22972297

2298-
static ParserStatus parseStmtCase(Parser &P, SourceLoc &CaseLoc,
2299-
SmallVectorImpl<CaseLabelItem> &LabelItems,
2300-
SmallVectorImpl<VarDecl *> &BoundDecls,
2301-
SourceLoc &ColonLoc) {
2298+
static ParserStatus
2299+
parseStmtCase(Parser &P, SourceLoc &CaseLoc,
2300+
SmallVectorImpl<CaseLabelItem> &LabelItems,
2301+
SmallVectorImpl<VarDecl *> &BoundDecls, SourceLoc &ColonLoc,
2302+
Optional<MutableArrayRef<VarDecl *>> &CaseBodyDecls) {
23022303
SyntaxParsingContext CaseContext(P.SyntaxContext,
23032304
SyntaxKind::SwitchCaseLabel);
23042305
ParserStatus Status;
@@ -2314,13 +2315,28 @@ static ParserStatus parseStmtCase(Parser &P, SourceLoc &CaseLoc,
23142315
GuardedPattern PatternResult;
23152316
parseGuardedPattern(P, PatternResult, Status, BoundDecls,
23162317
GuardedPatternContext::Case, isFirst);
2317-
LabelItems.push_back(
2318-
CaseLabelItem(PatternResult.ThePattern, PatternResult.WhereLoc,
2319-
PatternResult.Guard));
2318+
LabelItems.emplace_back(PatternResult.ThePattern, PatternResult.WhereLoc,
2319+
PatternResult.Guard);
23202320
isFirst = false;
23212321
if (!P.consumeIf(tok::comma))
23222322
break;
23232323
}
2324+
2325+
// Grab the first case label item pattern and use it to initialize the case
2326+
// body var decls.
2327+
SmallVector<VarDecl *, 4> tmp;
2328+
LabelItems.front().getPattern()->collectVariables(tmp);
2329+
auto Result = P.Context.AllocateUninitialized<VarDecl *>(tmp.size());
2330+
for (unsigned i : indices(tmp)) {
2331+
auto *vOld = tmp[i];
2332+
auto *vNew = new (P.Context) VarDecl(
2333+
/*IsStatic*/ false, vOld->getSpecifier(), false /*IsCaptureList*/,
2334+
vOld->getNameLoc(), vOld->getName(), vOld->getDeclContext());
2335+
vNew->setHasNonPatternBindingInit();
2336+
vNew->setImplicit();
2337+
Result[i] = vNew;
2338+
}
2339+
CaseBodyDecls.emplace(Result);
23242340
}
23252341

23262342
ColonLoc = P.Tok.getLoc();
@@ -2448,9 +2464,10 @@ ParserResult<CaseStmt> Parser::parseStmtCase(bool IsActive) {
24482464

24492465
SourceLoc CaseLoc;
24502466
SourceLoc ColonLoc;
2467+
Optional<MutableArrayRef<VarDecl *>> CaseBodyDecls;
24512468
if (Tok.is(tok::kw_case)) {
2452-
Status |=
2453-
::parseStmtCase(*this, CaseLoc, CaseLabelItems, BoundDecls, ColonLoc);
2469+
Status |= ::parseStmtCase(*this, CaseLoc, CaseLabelItems, BoundDecls,
2470+
ColonLoc, CaseBodyDecls);
24542471
} else if (Tok.is(tok::kw_default)) {
24552472
Status |= parseStmtCaseDefault(*this, CaseLoc, CaseLabelItems, ColonLoc);
24562473
} else {
@@ -2480,10 +2497,9 @@ ParserResult<CaseStmt> Parser::parseStmtCase(bool IsActive) {
24802497
}
24812498

24822499
return makeParserResult(
2483-
Status,
2484-
CaseStmt::create(Context, CaseLoc, CaseLabelItems, !BoundDecls.empty(),
2485-
UnknownAttrLoc, ColonLoc, Body, None,
2486-
FallthroughFinder::findFallthrough(Body)));
2500+
Status, CaseStmt::create(Context, CaseLoc, CaseLabelItems, UnknownAttrLoc,
2501+
ColonLoc, Body, CaseBodyDecls, None,
2502+
FallthroughFinder::findFallthrough(Body)));
24872503
}
24882504

24892505
/// stmt-pound-assert:

lib/Sema/DerivedConformanceCodingKey.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -235,9 +235,9 @@ deriveBodyCodingKey_enum_stringValue(AbstractFunctionDecl *strValDecl, void *) {
235235
auto *returnStmt = new (C) ReturnStmt(SourceLoc(), caseValue);
236236
auto *caseBody = BraceStmt::create(C, SourceLoc(), ASTNode(returnStmt),
237237
SourceLoc());
238-
cases.push_back(CaseStmt::create(C, SourceLoc(), labelItem,
239-
/*HasBoundDecls=*/false, SourceLoc(),
240-
SourceLoc(), caseBody));
238+
cases.push_back(CaseStmt::create(C, SourceLoc(), labelItem, SourceLoc(),
239+
SourceLoc(), caseBody,
240+
/*case body var decls*/ None));
241241
}
242242

243243
auto *selfRef = DerivedConformance::createSelfDeclRef(strValDecl);
@@ -303,9 +303,9 @@ deriveBodyCodingKey_init_stringValue(AbstractFunctionDecl *initDecl, void *) {
303303

304304
auto *body = BraceStmt::create(C, SourceLoc(), ASTNode(assignment),
305305
SourceLoc());
306-
cases.push_back(CaseStmt::create(C, SourceLoc(), labelItem,
307-
/*HasBoundDecls=*/false, SourceLoc(),
308-
SourceLoc(), body));
306+
cases.push_back(CaseStmt::create(C, SourceLoc(), labelItem, SourceLoc(),
307+
SourceLoc(), body,
308+
/*case body var decls*/ None));
309309
}
310310

311311
auto *anyPat = new (C) AnyPattern(SourceLoc());
@@ -315,9 +315,9 @@ deriveBodyCodingKey_init_stringValue(AbstractFunctionDecl *initDecl, void *) {
315315
auto *dfltReturnStmt = new (C) FailStmt(SourceLoc(), SourceLoc());
316316
auto *dfltBody = BraceStmt::create(C, SourceLoc(), ASTNode(dfltReturnStmt),
317317
SourceLoc());
318-
cases.push_back(CaseStmt::create(C, SourceLoc(), dfltLabelItem,
319-
/*HasBoundDecls=*/false, SourceLoc(),
320-
SourceLoc(), dfltBody));
318+
cases.push_back(CaseStmt::create(C, SourceLoc(), dfltLabelItem, SourceLoc(),
319+
SourceLoc(), dfltBody,
320+
/*case body var decls*/ None));
321321

322322
auto *stringValueDecl = initDecl->getParameters()->get(0);
323323
auto *stringValueRef = new (C) DeclRefExpr(stringValueDecl, DeclNameLoc(),

0 commit comments

Comments
 (0)