Skip to content

Commit 22a8b28

Browse files
authored
Merge pull request #83009 from jckarter/case-binding-addressable-scopes
SILGen: Properly set up addressability scopes for case pattern bindings.
2 parents 439afd6 + c24801a commit 22a8b28

File tree

5 files changed

+126
-31
lines changed

5 files changed

+126
-31
lines changed

lib/SILGen/SILGenDecl.cpp

Lines changed: 42 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include "swift/SIL/SILType.h"
3737
#include "swift/SIL/TypeLowering.h"
3838
#include "llvm/ADT/SmallString.h"
39+
#include "llvm/Support/ErrorHandling.h"
3940
#include <iterator>
4041

4142
using namespace swift;
@@ -699,20 +700,23 @@ class DeallocateLocalVariableAddressableBuffer : public Cleanup {
699700

700701
void emit(SILGenFunction &SGF, CleanupLocation l,
701702
ForUnwind_t forUnwind) override {
703+
auto addressableBuffer = SGF.getAddressableBufferInfo(vd);
704+
if (!addressableBuffer) {
705+
return;
706+
}
702707
auto found = SGF.VarLocs.find(vd);
703708
if (found == SGF.VarLocs.end()) {
704709
return;
705710
}
706-
auto &loc = found->second;
707711

708-
if (auto &state = loc.addressableBuffer.state) {
712+
if (auto *state = addressableBuffer->getState()) {
709713
// The addressable buffer was forced, so clean it up now.
710714
deallocateAddressable(SGF, l, *state);
711715
} else {
712716
// Remember this insert location in case we need to force the addressable
713717
// buffer later.
714718
SILInstruction *marker = SGF.B.createTuple(l, {});
715-
loc.addressableBuffer.cleanupPoints.emplace_back(marker);
719+
addressableBuffer->cleanupPoints.emplace_back(marker);
716720
}
717721
}
718722

@@ -2254,7 +2258,7 @@ SILGenFunction::getLocalVariableAddressableBuffer(VarDecl *decl,
22542258

22552259
auto value = foundVarLoc->second.value;
22562260
auto access = foundVarLoc->second.access;
2257-
auto *state = foundVarLoc->second.addressableBuffer.state.get();
2261+
auto *state = getAddressableBufferInfo(decl)->getState();
22582262

22592263
SILType fullyAbstractedTy = getLoweredType(AbstractionPattern::getOpaque(),
22602264
decl->getTypeInContext()->getRValueType());
@@ -2292,9 +2296,26 @@ SILGenFunction::getLocalVariableAddressableBuffer(VarDecl *decl,
22922296
SILValue reabstraction, allocStack, storeBorrow;
22932297
{
22942298
SavedInsertionPointRAII save(B);
2295-
ASSERT(AddressableBuffers.find(decl) != AddressableBuffers.end()
2296-
&& "local variable did not have an addressability scope set");
2297-
auto insertPoint = AddressableBuffers[decl].insertPoint;
2299+
SILInstruction *insertPoint = nullptr;
2300+
// Look through bindings that might alias the original addressable buffer
2301+
// (such as case block variables, which use an alias variable to represent the
2302+
// incoming value from all of the case label patterns).
2303+
VarDecl *origDecl = decl;
2304+
do {
2305+
auto bufferIter = AddressableBuffers.find(origDecl);
2306+
ASSERT(bufferIter != AddressableBuffers.end()
2307+
&& "local variable didn't have an addressability scope set");
2308+
2309+
insertPoint = bufferIter->second.getInsertPoint();
2310+
if (insertPoint) {
2311+
break;
2312+
}
2313+
2314+
origDecl = bufferIter->second.getOriginalForAlias();
2315+
ASSERT(origDecl && "no insert point or alias for addressable declaration!");
2316+
} while (true);
2317+
2318+
assert(insertPoint && "didn't find an insertion point for the addressable buffer");
22982319
B.setInsertionPoint(insertPoint);
22992320
auto allocStackTy = fullyAbstractedTy;
23002321
if (value->getType().isMoveOnlyWrapped()) {
@@ -2313,8 +2334,12 @@ SILGenFunction::getLocalVariableAddressableBuffer(VarDecl *decl,
23132334
SavedInsertionPointRAII save(B);
23142335
if (isa<ParamDecl>(decl)) {
23152336
B.setInsertionPoint(allocStack->getNextInstruction());
2337+
} else if (auto inst = value->getDefiningInstruction()) {
2338+
B.setInsertionPoint(inst->getParent(), std::next(inst->getIterator()));
2339+
} else if (auto arg = dyn_cast<SILArgument>(value)) {
2340+
B.setInsertionPoint(arg->getParent()->begin());
23162341
} else {
2317-
B.setInsertionPoint(value->getNextInstruction());
2342+
llvm_unreachable("unexpected value source!");
23182343
}
23192344
auto declarationLoc = value->getDefiningInsertionPoint()->getLoc();
23202345

@@ -2334,17 +2359,15 @@ SILGenFunction::getLocalVariableAddressableBuffer(VarDecl *decl,
23342359
}
23352360

23362361
// Record the addressable representation.
2337-
auto &addressableBuffer = VarLocs[decl].addressableBuffer;
2338-
addressableBuffer.state
2339-
= std::make_unique<VarLoc::AddressableBuffer::State>(reabstraction,
2340-
allocStack,
2341-
storeBorrow);
2342-
auto *newState = addressableBuffer.state.get();
2362+
auto *addressableBuffer = getAddressableBufferInfo(decl);
2363+
auto *newState
2364+
= new VarLoc::AddressableBuffer::State(reabstraction, allocStack, storeBorrow);
2365+
addressableBuffer->stateOrAlias = newState;
23432366

23442367
// Emit cleanups on any paths where we previously would have cleaned up
23452368
// the addressable representation if it had been forced earlier.
2346-
decltype(addressableBuffer.cleanupPoints) cleanupPoints;
2347-
cleanupPoints.swap(addressableBuffer.cleanupPoints);
2369+
decltype(addressableBuffer->cleanupPoints) cleanupPoints;
2370+
cleanupPoints.swap(addressableBuffer->cleanupPoints);
23482371

23492372
for (SILInstruction *cleanupPoint : cleanupPoints) {
23502373
SavedInsertionPointRAII insertCleanup(B, cleanupPoint);
@@ -2391,4 +2414,7 @@ SILGenFunction::VarLoc::AddressableBuffer::~AddressableBuffer() {
23912414
for (auto cleanupPoint : cleanupPoints) {
23922415
cleanupPoint->eraseFromParent();
23932416
}
2417+
if (auto state = stateOrAlias.dyn_cast<State*>()) {
2418+
delete state;
2419+
}
23942420
}

lib/SILGen/SILGenFunction.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2004,3 +2004,20 @@ void SILGenFunction::emitAssignOrInit(SILLocation loc, ManagedValue selfValue,
20042004
newValue.forward(*this), initFRef, setterFRef,
20052005
AssignOrInitInst::Unknown);
20062006
}
2007+
2008+
SILGenFunction::VarLoc::AddressableBuffer *
2009+
SILGenFunction::getAddressableBufferInfo(ValueDecl *vd) {
2010+
do {
2011+
auto found = VarLocs.find(vd);
2012+
if (found == VarLocs.end()) {
2013+
return nullptr;
2014+
}
2015+
2016+
if (auto orig = found->second.addressableBuffer.stateOrAlias
2017+
.dyn_cast<VarDecl*>()) {
2018+
vd = orig;
2019+
continue;
2020+
}
2021+
return &found->second.addressableBuffer;
2022+
} while (true);
2023+
}

lib/SILGen/SILGenFunction.h

Lines changed: 44 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "swift/Basic/ProfileCounter.h"
2828
#include "swift/Basic/Statistic.h"
2929
#include "swift/SIL/SILBuilder.h"
30+
#include "swift/SIL/SILInstruction.h"
3031
#include "swift/SIL/SILType.h"
3132
#include "llvm/ADT/PointerIntPair.h"
3233

@@ -497,27 +498,42 @@ class LLVM_LIBRARY_VISIBILITY SILGenFunction
497498
{}
498499
};
499500

500-
std::unique_ptr<State> state = nullptr;
501-
501+
llvm::PointerUnion<State *, VarDecl*> stateOrAlias = (State*)nullptr;
502+
502503
// If the variable cleanup is triggered before the addressable
503504
// representation is demanded, but the addressable representation
504505
// gets demanded later, we save the insertion points where the
505506
// representation would be cleaned up so we can backfill them.
506507
llvm::SmallVector<SILInstruction*, 1> cleanupPoints;
507508

508509
AddressableBuffer() = default;
510+
511+
AddressableBuffer(VarDecl *original)
512+
: stateOrAlias(original)
513+
{
514+
}
509515

510516
AddressableBuffer(AddressableBuffer &&other)
511-
: state(std::move(other.state))
517+
: stateOrAlias(other.stateOrAlias)
512518
{
519+
other.stateOrAlias = (State*)nullptr;
513520
cleanupPoints.swap(other.cleanupPoints);
514521
}
515522

516523
AddressableBuffer &operator=(AddressableBuffer &&other) {
517-
state = std::move(other.state);
524+
if (auto state = stateOrAlias.dyn_cast<State*>()) {
525+
delete state;
526+
}
527+
stateOrAlias = other.stateOrAlias;
518528
cleanupPoints.swap(other.cleanupPoints);
519529
return *this;
520530
}
531+
532+
State *getState() {
533+
ASSERT(!stateOrAlias.is<VarDecl*>()
534+
&& "must get state from original AddressableBuffer");
535+
return stateOrAlias.dyn_cast<State*>();
536+
}
521537

522538
~AddressableBuffer();
523539
};
@@ -535,33 +551,50 @@ class LLVM_LIBRARY_VISIBILITY SILGenFunction
535551
/// emitted. The map is queried to produce the lvalue for a DeclRefExpr to
536552
/// a local variable.
537553
llvm::DenseMap<ValueDecl*, VarLoc> VarLocs;
554+
555+
VarLoc::AddressableBuffer *getAddressableBufferInfo(ValueDecl *vd);
538556

539557
// Represents an addressable buffer that has been allocated but not yet used.
540558
struct PreparedAddressableBuffer {
541-
SILInstruction *insertPoint = nullptr;
559+
llvm::PointerUnion<SILInstruction *, VarDecl *> insertPointOrAlias
560+
= (SILInstruction*)nullptr;
542561

543562
PreparedAddressableBuffer() = default;
544563

545564
PreparedAddressableBuffer(SILInstruction *insertPoint)
546-
: insertPoint(insertPoint)
565+
: insertPointOrAlias(insertPoint)
547566
{
548567
ASSERT(insertPoint && "null insertion point provided");
549568
}
569+
570+
PreparedAddressableBuffer(VarDecl *alias)
571+
: insertPointOrAlias(alias)
572+
{
573+
ASSERT(alias && "null alias provided");
574+
}
550575

551576
PreparedAddressableBuffer(PreparedAddressableBuffer &&other)
552-
: insertPoint(other.insertPoint)
577+
: insertPointOrAlias(other.insertPointOrAlias)
553578
{
554-
other.insertPoint = nullptr;
579+
other.insertPointOrAlias = (SILInstruction*)nullptr;
555580
}
556581

557582
PreparedAddressableBuffer &operator=(PreparedAddressableBuffer &&other) {
558-
insertPoint = other.insertPoint;
559-
other.insertPoint = nullptr;
583+
insertPointOrAlias = other.insertPointOrAlias;
584+
other.insertPointOrAlias = nullptr;
560585
return *this;
561586
}
587+
588+
SILInstruction *getInsertPoint() const {
589+
return insertPointOrAlias.dyn_cast<SILInstruction*>();
590+
}
562591

592+
VarDecl *getOriginalForAlias() const {
593+
return insertPointOrAlias.dyn_cast<VarDecl*>();
594+
}
595+
563596
~PreparedAddressableBuffer() {
564-
if (insertPoint) {
597+
if (auto insertPoint = getInsertPoint()) {
565598
// Remove the insertion point if it went unused.
566599
insertPoint->eraseFromParent();
567600
}

lib/SILGen/SILGenPattern.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
#include "swift/AST/SILOptions.h"
2525
#include "swift/AST/SubstitutionMap.h"
2626
#include "swift/AST/Types.h"
27-
#include "swift/Basic/Assertions.h"
2827
#include "swift/Basic/Defer.h"
2928
#include "swift/Basic/ProfileCounter.h"
3029
#include "swift/Basic/STLExtras.h"
@@ -241,7 +240,7 @@ static bool isWildcardPattern(const Pattern *p) {
241240

242241
/// Check to see if the given pattern is a specializing pattern,
243242
/// and return a semantic pattern for it.
244-
Pattern *getSpecializingPattern(Pattern *p) {
243+
static Pattern *getSpecializingPattern(Pattern *p) {
245244
// Empty entries are basically AnyPatterns.
246245
if (!p) return nullptr;
247246

@@ -975,9 +974,8 @@ class SpecializedArgForwarder : private ArgForwarderBase {
975974
if (IsFinalUse) {
976975
ArgForwarderBase::forwardIntoIrrefutable(value);
977976
return value;
978-
} else {
979-
return ArgForwarderBase::forward(value, loc);
980977
}
978+
return ArgForwarderBase::forward(value, loc);
981979
}
982980
};
983981

@@ -3175,6 +3173,9 @@ static void switchCaseStmtSuccessCallback(SILGenFunction &SGF,
31753173
expectedLoc = SILGenFunction::VarLoc(vdLoc->second.value,
31763174
vdLoc->second.access,
31773175
vdLoc->second.box);
3176+
expectedLoc.addressableBuffer = vd;
3177+
// Alias the addressable buffer for the two variables.
3178+
SGF.AddressableBuffers[expected] = vd;
31783179

31793180
// Emit a debug description for the variable, nested within a scope
31803181
// for the pattern match.

test/SILGen/addressable_params.swift

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,21 @@ struct Foo {
4646
}
4747
}
4848
}
49+
50+
enum TestEnum {
51+
case foo(String)
52+
case bar(String)
53+
}
54+
55+
func addressableParam(_: @_addressable String) -> Bool { true }
56+
57+
func testAddressableSwitchBinding(e: TestEnum) -> Bool {
58+
return switch e {
59+
case .foo(let f) where addressableParam(f):
60+
true
61+
case .bar(let b):
62+
addressableParam(b)
63+
default:
64+
false
65+
}
66+
}

0 commit comments

Comments
 (0)