Skip to content

Commit 195f42e

Browse files
committed
[capture-promotion] Change capture promotion to accumulate mutations/escape uses instead of just bailing.
In a subsequent commit, I am going to use this to emit a warning if this fails for captures by concurrent functions.
1 parent b16f340 commit 195f42e

File tree

1 file changed

+114
-57
lines changed

1 file changed

+114
-57
lines changed

lib/SILOptimizer/Mandatory/CapturePromotion.cpp

Lines changed: 114 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,10 @@
4343
//===----------------------------------------------------------------------===//
4444

4545
#define DEBUG_TYPE "sil-capture-promotion"
46+
4647
#include "swift/AST/GenericEnvironment.h"
4748
#include "swift/SIL/SILCloner.h"
49+
#include "swift/SIL/SILInstruction.h"
4850
#include "swift/SIL/TypeSubstCloner.h"
4951
#include "swift/SILOptimizer/PassManager/Passes.h"
5052
#include "swift/SILOptimizer/PassManager/Transforms.h"
@@ -756,8 +758,12 @@ void ClosureCloner::visitLoadInst(LoadInst *li) {
756758
namespace {
757759

758760
struct EscapeMutationScanningState {
759-
/// The list of mutations that we found while checking for escapes.
760-
llvm::SmallVector<SILInstruction *, 8> foundMutations;
761+
/// The list of mutations in the partial_apply caller that we found.
762+
SmallVector<Operand *, 8> accumulatedMutations;
763+
764+
/// The list of escapes in the partial_apply caller/callee of the box that we
765+
/// found.
766+
SmallVector<Operand *, 8> accumulatedEscapes;
761767

762768
/// A flag that we use to ensure that we only ever see 1 project_box on an
763769
/// alloc_box.
@@ -786,16 +792,20 @@ static bool isNonMutatingLoad(SILInstruction *inst) {
786792
return li->getOwnershipQualifier() != LoadOwnershipQualifier::Take;
787793
}
788794

789-
/// Given a partial_apply instruction and the argument index into its
790-
/// callee's argument list of a box argument (which is followed by an argument
791-
/// for the address of the box's contents), return true if the closure is known
792-
/// not to mutate the captured variable.
793-
static bool isNonMutatingCapture(SILArgument *boxArg) {
795+
/// Given a partial_apply instruction and the argument index into its callee's
796+
/// argument list of a box argument (which is followed by an argument for the
797+
/// address of the box's contents), return true if this box has mutating
798+
/// captures. Return false otherwise. All of the mutating captures that we find
799+
/// are placed into \p accumulatedMutatingUses.
800+
static bool getPartialApplyArgMutationsAndEscapes(
801+
SILArgument *boxArg, SmallVectorImpl<Operand *> &accumulatedMutatingUses,
802+
SmallVectorImpl<Operand *> &accumulatedEscapes) {
794803
SmallVector<ProjectBoxInst *, 2> projectBoxInsts;
795804

796805
// Conservatively do not allow any use of the box argument other than a
797806
// strong_release or projection, since this is the pattern expected from
798807
// SILGen.
808+
SmallVector<Operand *, 32> incrementalEscapes;
799809
for (auto *use : boxArg->getUses()) {
800810
if (isa<StrongReleaseInst>(use->getUser()) ||
801811
isa<DestroyValueInst>(use->getUser()))
@@ -806,7 +816,7 @@ static bool isNonMutatingCapture(SILArgument *boxArg) {
806816
continue;
807817
}
808818

809-
return false;
819+
incrementalEscapes.push_back(use);
810820
}
811821

812822
// Only allow loads of projections, either directly or via
@@ -815,33 +825,44 @@ static bool isNonMutatingCapture(SILArgument *boxArg) {
815825
// TODO: This seems overly limited. Why not projections of tuples and other
816826
// stuff? Also, why not recursive struct elements? This should be a helper
817827
// function that mirrors isNonEscapingUse.
818-
auto isAddrUseMutating = [](SILInstruction *addrInst) {
828+
auto checkIfAddrUseMutating = [&](Operand *addrUse) -> bool {
829+
unsigned initSize = incrementalEscapes.size();
830+
auto *addrInst = addrUse->getUser();
819831
if (auto *seai = dyn_cast<StructElementAddrInst>(addrInst)) {
820-
return all_of(seai->getUses(), [](Operand *op) -> bool {
821-
return isNonMutatingLoad(op->getUser());
822-
});
832+
for (auto *seaiUse : seai->getUses()) {
833+
if (!isNonMutatingLoad(seaiUse->getUser())) {
834+
incrementalEscapes.push_back(seaiUse);
835+
}
836+
}
837+
return incrementalEscapes.size() != initSize;
823838
}
824839

825-
return isNonMutatingLoad(addrInst) || isa<DebugValueAddrInst>(addrInst) ||
826-
isa<MarkFunctionEscapeInst>(addrInst) ||
827-
isa<EndAccessInst>(addrInst);
840+
if (isNonMutatingLoad(addrInst) || isa<DebugValueAddrInst>(addrInst) ||
841+
isa<MarkFunctionEscapeInst>(addrInst) || isa<EndAccessInst>(addrInst)) {
842+
return false;
843+
}
844+
845+
incrementalEscapes.push_back(addrUse);
846+
return true;
828847
};
829848

830849
for (auto *pbi : projectBoxInsts) {
831850
for (auto *use : pbi->getUses()) {
832851
if (auto *bai = dyn_cast<BeginAccessInst>(use->getUser())) {
833852
for (auto *accessUseOper : bai->getUses()) {
834-
if (!isAddrUseMutating(accessUseOper->getUser()))
835-
return false;
853+
checkIfAddrUseMutating(accessUseOper);
836854
}
837855
continue;
838856
}
839857

840-
if (!isAddrUseMutating(use->getUser()))
841-
return false;
858+
checkIfAddrUseMutating(use);
842859
}
843860
}
844861

862+
if (incrementalEscapes.empty())
863+
return false;
864+
while (!incrementalEscapes.empty())
865+
accumulatedEscapes.push_back(incrementalEscapes.pop_back_val());
845866
return true;
846867
}
847868

@@ -852,11 +873,14 @@ bool isPartialApplyNonEscapingUser(Operand *currentOp, PartialApplyInst *pai,
852873
unsigned opNo = currentOp->getOperandNumber();
853874
assert(opNo != 0 && "Alloc box used as callee of partial apply?");
854875

855-
// If we've already seen this partial apply, then it means the same alloc
856-
// box is being captured twice by the same closure, which is odd and
857-
// unexpected: bail instead of trying to handle this case.
876+
// If we've already seen this partial apply, then it means the same alloc box
877+
// is being captured twice by the same closure, which is odd and unexpected:
878+
// bail instead of trying to handle this case.
858879
if (state.globalIndexMap.count(pai)) {
880+
// TODO: Is it correct to treat this like an escape? We are just currently
881+
// flagging all failures as warnings.
859882
LLVM_DEBUG(llvm::dbgs() << " FAIL! Already seen.\n");
883+
state.accumulatedEscapes.push_back(currentOp);
860884
return false;
861885
}
862886

@@ -877,6 +901,7 @@ bool isPartialApplyNonEscapingUser(Operand *currentOp, PartialApplyInst *pai,
877901
if (!fn || !fn->isDefinition() || fn->isDynamicallyReplaceable()) {
878902
LLVM_DEBUG(llvm::dbgs() << " FAIL! Not a direct function definition "
879903
"reference.\n");
904+
state.accumulatedEscapes.push_back(currentOp);
880905
return false;
881906
}
882907

@@ -893,14 +918,17 @@ bool isPartialApplyNonEscapingUser(Operand *currentOp, PartialApplyInst *pai,
893918
.isAddressOnly(*f)) {
894919
LLVM_DEBUG(llvm::dbgs() << " FAIL! Box is an address only "
895920
"argument!\n");
921+
state.accumulatedEscapes.push_back(currentOp);
896922
return false;
897923
}
898924

899925
// Verify that this closure is known not to mutate the captured value; if
900926
// it does, then conservatively refuse to promote any captures of this
901927
// value.
902-
if (!isNonMutatingCapture(boxArg)) {
903-
LLVM_DEBUG(llvm::dbgs() << " FAIL: Have a mutating capture!\n");
928+
if (getPartialApplyArgMutationsAndEscapes(boxArg, state.accumulatedMutations,
929+
state.accumulatedEscapes)) {
930+
LLVM_DEBUG(llvm::dbgs() << " FAIL: Have a mutation or escape of a "
931+
"partial apply arg?!\n");
904932
return false;
905933
}
906934

@@ -920,15 +948,17 @@ namespace {
920948

921949
class NonEscapingUserVisitor
922950
: public SILInstructionVisitor<NonEscapingUserVisitor, bool> {
923-
llvm::SmallVector<Operand *, 32> worklist;
924-
llvm::SmallVectorImpl<SILInstruction *> &foundMutations;
951+
SmallVector<Operand *, 32> worklist;
952+
SmallVectorImpl<Operand *> &accumulatedMutations;
953+
SmallVectorImpl<Operand *> &accumulatedEscapes;
925954
NullablePtr<Operand> currentOp;
926955

927956
public:
928-
NonEscapingUserVisitor(
929-
Operand *initialOperand,
930-
llvm::SmallVectorImpl<SILInstruction *> &foundMutations)
931-
: worklist(), foundMutations(foundMutations), currentOp() {
957+
NonEscapingUserVisitor(Operand *initialOperand,
958+
SmallVectorImpl<Operand *> &accumulatedMutations,
959+
SmallVectorImpl<Operand *> &accumulatedEscapes)
960+
: worklist(), accumulatedMutations(accumulatedMutations),
961+
accumulatedEscapes(accumulatedEscapes), currentOp() {
932962
worklist.push_back(initialOperand);
933963
}
934964

@@ -937,6 +967,15 @@ class NonEscapingUserVisitor
937967
NonEscapingUserVisitor(NonEscapingUserVisitor &&) = delete;
938968
NonEscapingUserVisitor &operator=(NonEscapingUserVisitor &&) = delete;
939969

970+
private:
971+
void markCurrentOpAsMutation() {
972+
accumulatedMutations.push_back(currentOp.get());
973+
}
974+
void markCurrentOpAsEscape() {
975+
accumulatedEscapes.push_back(currentOp.get());
976+
}
977+
978+
public:
940979
bool compute() {
941980
while (!worklist.empty()) {
942981
currentOp = worklist.pop_back_val();
@@ -964,6 +1003,7 @@ class NonEscapingUserVisitor
9641003
bool visitSILInstruction(SILInstruction *inst) {
9651004
LLVM_DEBUG(llvm::dbgs()
9661005
<< " FAIL! Have unknown escaping user: " << *inst);
1006+
markCurrentOpAsEscape();
9671007
return false;
9681008
}
9691009

@@ -979,7 +1019,7 @@ class NonEscapingUserVisitor
9791019
#undef ALWAYS_NON_ESCAPING_INST
9801020

9811021
bool visitDeallocBoxInst(DeallocBoxInst *dbi) {
982-
foundMutations.push_back(dbi);
1022+
markCurrentOpAsMutation();
9831023
return true;
9841024
}
9851025

@@ -992,9 +1032,10 @@ class NonEscapingUserVisitor
9921032
if (!convention.isIndirectConvention()) {
9931033
LLVM_DEBUG(llvm::dbgs()
9941034
<< " FAIL! Found non indirect apply user: " << *ai);
1035+
markCurrentOpAsEscape();
9951036
return false;
9961037
}
997-
foundMutations.push_back(ai);
1038+
markCurrentOpAsMutation();
9981039
return true;
9991040
}
10001041

@@ -1017,7 +1058,7 @@ class NonEscapingUserVisitor
10171058
#define RECURSIVE_INST_VISITOR(MUTATING, INST) \
10181059
bool visit##INST##Inst(INST##Inst *i) { \
10191060
if (bool(detail::MUTATING)) { \
1020-
foundMutations.push_back(i); \
1061+
markCurrentOpAsMutation(); \
10211062
} \
10221063
addUsesToWorklist(i); \
10231064
return true; \
@@ -1044,25 +1085,27 @@ class NonEscapingUserVisitor
10441085

10451086
bool visitCopyAddrInst(CopyAddrInst *cai) {
10461087
if (currentOp.get()->getOperandNumber() == 1 || cai->isTakeOfSrc())
1047-
foundMutations.push_back(cai);
1088+
markCurrentOpAsMutation();
10481089
return true;
10491090
}
10501091

10511092
bool visitStoreInst(StoreInst *si) {
10521093
if (currentOp.get()->getOperandNumber() != 1) {
10531094
LLVM_DEBUG(llvm::dbgs() << " FAIL! Found store of pointer: " << *si);
1095+
markCurrentOpAsEscape();
10541096
return false;
10551097
}
1056-
foundMutations.push_back(si);
1098+
markCurrentOpAsMutation();
10571099
return true;
10581100
}
10591101

10601102
bool visitAssignInst(AssignInst *ai) {
10611103
if (currentOp.get()->getOperandNumber() != 1) {
10621104
LLVM_DEBUG(llvm::dbgs() << " FAIL! Found store of pointer: " << *ai);
1105+
markCurrentOpAsEscape();
10631106
return false;
10641107
}
1065-
foundMutations.push_back(ai);
1108+
markCurrentOpAsMutation();
10661109
return true;
10671110
}
10681111
};
@@ -1075,7 +1118,9 @@ class NonEscapingUserVisitor
10751118
/// the Mutations vector.
10761119
static bool isNonEscapingUse(Operand *initialOp,
10771120
EscapeMutationScanningState &state) {
1078-
return NonEscapingUserVisitor(initialOp, state.foundMutations).compute();
1121+
return NonEscapingUserVisitor(initialOp, state.accumulatedMutations,
1122+
state.accumulatedEscapes)
1123+
.compute();
10791124
}
10801125

10811126
static bool isProjectBoxNonEscapingUse(ProjectBoxInst *pbi,
@@ -1097,12 +1142,12 @@ static bool isProjectBoxNonEscapingUse(ProjectBoxInst *pbi,
10971142
// Top Level AllocBox Escape/Mutation Analysis
10981143
//===----------------------------------------------------------------------===//
10991144

1100-
static bool scanUsesForEscapesAndMutations(Operand *op,
1101-
EscapeMutationScanningState &state) {
1145+
static bool findEscapeOrMutationUses(Operand *op,
1146+
EscapeMutationScanningState &state) {
11021147
SILInstruction *user = op->getUser();
11031148

11041149
if (auto *pai = dyn_cast<PartialApplyInst>(user)) {
1105-
return isPartialApplyNonEscapingUser(op, pai, state);
1150+
return !isPartialApplyNonEscapingUser(op, pai, state);
11061151
}
11071152

11081153
// A mark_dependence user on a partial_apply is safe.
@@ -1112,7 +1157,11 @@ static bool scanUsesForEscapesAndMutations(Operand *op,
11121157
while ((mdi = dyn_cast<MarkDependenceInst>(parent))) {
11131158
parent = mdi->getValue();
11141159
}
1115-
return isa<PartialApplyInst>(parent);
1160+
if (isa<PartialApplyInst>(parent))
1161+
return false;
1162+
state.accumulatedEscapes.push_back(
1163+
&mdi->getOperandRef(MarkDependenceInst::Value));
1164+
return true;
11161165
}
11171166
}
11181167

@@ -1121,9 +1170,9 @@ static bool scanUsesForEscapesAndMutations(Operand *op,
11211170
// can be seen since there is no code for reasoning about multiple
11221171
// boxes. Just put in the restriction so we are consistent.
11231172
if (state.sawProjectBoxInst)
1124-
return false;
1173+
return true;
11251174
state.sawProjectBoxInst = true;
1126-
return isProjectBoxNonEscapingUse(pbi, state);
1175+
return !isProjectBoxNonEscapingUse(pbi, state);
11271176
}
11281177

11291178
// Given a top level copy value use or mark_uninitialized, check all of its
@@ -1134,10 +1183,11 @@ static bool scanUsesForEscapesAndMutations(Operand *op,
11341183
// derived from a projection like instruction). In fact such a thing may not
11351184
// even make any sense!
11361185
if (isa<CopyValueInst>(user) || isa<MarkUninitializedInst>(user)) {
1137-
return all_of(cast<SingleValueInstruction>(user)->getUses(),
1138-
[&state](Operand *userOp) -> bool {
1139-
return scanUsesForEscapesAndMutations(userOp, state);
1140-
});
1186+
bool foundSomeMutations = false;
1187+
for (auto *use : cast<SingleValueInstruction>(user)->getUses()) {
1188+
foundSomeMutations |= findEscapeOrMutationUses(use, state);
1189+
}
1190+
return foundSomeMutations;
11411191
}
11421192

11431193
// Verify that this use does not otherwise allow the alloc_box to
@@ -1153,14 +1203,20 @@ static bool
11531203
examineAllocBoxInst(AllocBoxInst *abi, ReachabilityInfo &ri,
11541204
llvm::DenseMap<PartialApplyInst *, unsigned> &im) {
11551205
LLVM_DEBUG(llvm::dbgs() << "Visiting alloc box: " << *abi);
1156-
EscapeMutationScanningState state{{}, false, im};
1206+
EscapeMutationScanningState state{{}, {}, false, im};
11571207

1158-
// Scan the box for interesting uses.
1159-
if (any_of(abi->getUses(), [&state](Operand *op) {
1160-
return !scanUsesForEscapesAndMutations(op, state);
1161-
})) {
1208+
// Scan the box for escaping or mutating uses.
1209+
for (auto *use : abi->getUses()) {
1210+
findEscapeOrMutationUses(use, state);
1211+
}
1212+
1213+
if (!state.accumulatedEscapes.empty()) {
11621214
LLVM_DEBUG(llvm::dbgs()
1163-
<< "Found an escaping use! Can not optimize this alloc box?!\n");
1215+
<< "Found escaping uses! Can not optimize this alloc box?!\n");
1216+
while (!state.accumulatedEscapes.empty()) {
1217+
auto *escapingUse = state.accumulatedEscapes.pop_back_val();
1218+
LLVM_DEBUG(llvm::dbgs() << "Escaping use: " << *escapingUse->getUser());
1219+
}
11641220
return false;
11651221
}
11661222

@@ -1183,17 +1239,18 @@ examineAllocBoxInst(AllocBoxInst *abi, ReachabilityInfo &ri,
11831239
LLVM_DEBUG(llvm::dbgs()
11841240
<< "Checking for any mutations that invalidate captures...\n");
11851241
// Loop over all mutations to possibly invalidate captures.
1186-
for (auto *inst : state.foundMutations) {
1242+
for (auto *use : state.accumulatedMutations) {
11871243
auto iter = im.begin();
11881244
while (iter != im.end()) {
1245+
auto *user = use->getUser();
11891246
auto *pai = iter->first;
11901247
// The mutation invalidates a capture if it occurs in a block reachable
11911248
// from the block the partial_apply is in, or if it is in the same
11921249
// block is after the partial_apply.
1193-
if (ri.isReachable(pai->getParent(), inst->getParent()) ||
1194-
(pai->getParent() == inst->getParent() && isAfter(pai, inst))) {
1250+
if (ri.isReachable(pai->getParent(), user->getParent()) ||
1251+
(pai->getParent() == user->getParent() && isAfter(pai, user))) {
11951252
LLVM_DEBUG(llvm::dbgs() << " Invalidating: " << *pai);
1196-
LLVM_DEBUG(llvm::dbgs() << " Because of user: " << *inst);
1253+
LLVM_DEBUG(llvm::dbgs() << " Because of user: " << *user);
11971254
auto prev = iter++;
11981255
im.erase(prev);
11991256
continue;

0 commit comments

Comments
 (0)