Skip to content

Commit dc13b1f

Browse files
committed
[AST] Tail-allocate case body variables on CaseStmt
1 parent c02c69a commit dc13b1f

File tree

2 files changed

+23
-13
lines changed

2 files changed

+23
-13
lines changed

include/swift/AST/Stmt.h

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,9 @@ class alignas(8) Stmt : public ASTAllocated<Stmt> {
8484
NumElements : 32
8585
);
8686

87-
SWIFT_INLINE_BITFIELD_FULL(CaseStmt, Stmt, 32,
87+
SWIFT_INLINE_BITFIELD_FULL(CaseStmt, Stmt, 16+32,
8888
: NumPadBits,
89+
NumCaseBodyVars : 16,
8990
NumPatterns : 32
9091
);
9192

@@ -1210,8 +1211,8 @@ enum CaseParentKind { Switch, DoCatch };
12101211
///
12111212
class CaseStmt final
12121213
: public Stmt,
1213-
private llvm::TrailingObjects<CaseStmt, FallthroughStmt *,
1214-
CaseLabelItem> {
1214+
private llvm::TrailingObjects<CaseStmt, FallthroughStmt *, CaseLabelItem,
1215+
VarDecl *> {
12151216
friend TrailingObjects;
12161217

12171218
Stmt *ParentStmt = nullptr;
@@ -1222,14 +1223,17 @@ class CaseStmt final
12221223

12231224
llvm::PointerIntPair<BraceStmt *, 1, bool> BodyAndHasFallthrough;
12241225

1225-
ArrayRef<VarDecl *> CaseBodyVariables;
1226-
12271226
CaseStmt(CaseParentKind ParentKind, SourceLoc ItemIntroducerLoc,
12281227
ArrayRef<CaseLabelItem> CaseLabelItems, SourceLoc UnknownAttrLoc,
12291228
SourceLoc ItemTerminatorLoc, BraceStmt *Body,
12301229
ArrayRef<VarDecl *> CaseBodyVariables, std::optional<bool> Implicit,
12311230
NullablePtr<FallthroughStmt> fallthroughStmt);
12321231

1232+
MutableArrayRef<VarDecl *> getCaseBodyVariablesBuffer() {
1233+
return {getTrailingObjects<VarDecl *>(),
1234+
static_cast<size_t>(Bits.CaseStmt.NumCaseBodyVars)};
1235+
}
1236+
12331237
public:
12341238
/// Create a parsed 'case'/'default' for 'switch' statement.
12351239
static CaseStmt *
@@ -1296,7 +1300,7 @@ class CaseStmt final
12961300
void setBody(BraceStmt *body) { BodyAndHasFallthrough.setPointer(body); }
12971301

12981302
/// True if the case block declares any patterns with local variable bindings.
1299-
bool hasCaseBodyVariables() const { return !CaseBodyVariables.empty(); }
1303+
bool hasCaseBodyVariables() const { return !getCaseBodyVariables().empty(); }
13001304

13011305
/// Get the source location of the 'case', 'default', or 'catch' of the first
13021306
/// label.
@@ -1349,7 +1353,7 @@ class CaseStmt final
13491353

13501354
/// Return an ArrayRef containing the case body variables of this CaseStmt.
13511355
ArrayRef<VarDecl *> getCaseBodyVariables() const {
1352-
return CaseBodyVariables;
1356+
return const_cast<CaseStmt *>(this)->getCaseBodyVariablesBuffer();
13531357
}
13541358

13551359
/// Find the next case statement within the same 'switch' or 'do-catch',

lib/AST/Stmt.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -758,9 +758,11 @@ CaseStmt::CaseStmt(CaseParentKind parentKind, SourceLoc itemIntroducerLoc,
758758
: Stmt(StmtKind::Case, getDefaultImplicitFlag(implicit, itemIntroducerLoc)),
759759
UnknownAttrLoc(unknownAttrLoc), ItemIntroducerLoc(itemIntroducerLoc),
760760
ItemTerminatorLoc(itemTerminatorLoc), ParentKind(parentKind),
761-
BodyAndHasFallthrough(body, fallthroughStmt.isNonNull()),
762-
CaseBodyVariables(caseBodyVariables) {
761+
BodyAndHasFallthrough(body, fallthroughStmt.isNonNull()) {
763762
Bits.CaseStmt.NumPatterns = caseLabelItems.size();
763+
Bits.CaseStmt.NumCaseBodyVars = caseBodyVariables.size();
764+
ASSERT(Bits.CaseStmt.NumCaseBodyVars == caseBodyVariables.size() &&
765+
"too many case body vars");
764766
assert(Bits.CaseStmt.NumPatterns > 0 &&
765767
"case block must have at least one pattern");
766768
assert(
@@ -770,6 +772,9 @@ CaseStmt::CaseStmt(CaseParentKind parentKind, SourceLoc itemIntroducerLoc,
770772
*getTrailingObjects<FallthroughStmt *>() = fallthroughStmt.get();
771773
}
772774

775+
std::uninitialized_copy(caseBodyVariables.begin(), caseBodyVariables.end(),
776+
getCaseBodyVariablesBuffer().begin());
777+
773778
MutableArrayRef<CaseLabelItem> items{getTrailingObjects<CaseLabelItem>(),
774779
static_cast<size_t>(Bits.CaseStmt.NumPatterns)};
775780

@@ -914,10 +919,11 @@ CaseStmt *CaseStmt::create(ASTContext &ctx, CaseParentKind ParentKind,
914919
BraceStmt *body, ArrayRef<VarDecl *> caseVarDecls,
915920
std::optional<bool> implicit,
916921
NullablePtr<FallthroughStmt> fallthroughStmt) {
917-
void *mem =
918-
ctx.Allocate(totalSizeToAlloc<FallthroughStmt *, CaseLabelItem>(
919-
fallthroughStmt.isNonNull(), caseLabelItems.size()),
920-
alignof(CaseStmt));
922+
void *mem = ctx.Allocate(
923+
totalSizeToAlloc<FallthroughStmt *, CaseLabelItem, VarDecl *>(
924+
fallthroughStmt.isNonNull(), caseLabelItems.size(),
925+
caseVarDecls.size()),
926+
alignof(CaseStmt));
921927
return ::new (mem)
922928
CaseStmt(ParentKind, caseLoc, caseLabelItems, unknownAttrLoc, colonLoc,
923929
body, caseVarDecls, implicit, fallthroughStmt);

0 commit comments

Comments
 (0)