43
43
// ===----------------------------------------------------------------------===//
44
44
45
45
#define DEBUG_TYPE " sil-capture-promotion"
46
+
46
47
#include " swift/AST/GenericEnvironment.h"
47
48
#include " swift/SIL/SILCloner.h"
49
+ #include " swift/SIL/SILInstruction.h"
48
50
#include " swift/SIL/TypeSubstCloner.h"
49
51
#include " swift/SILOptimizer/PassManager/Passes.h"
50
52
#include " swift/SILOptimizer/PassManager/Transforms.h"
@@ -756,8 +758,12 @@ void ClosureCloner::visitLoadInst(LoadInst *li) {
756
758
namespace {
757
759
758
760
struct EscapeMutationScanningState {
759
- // / The list of mutations that we found while checking for escapes.
760
- llvm::SmallVector<SILInstruction *, 8 > foundMutations;
761
+ // / The list of mutations in the partial_apply caller that we found.
762
+ SmallVector<Operand *, 8 > accumulatedMutations;
763
+
764
+ // / The list of escapes in the partial_apply caller/callee of the box that we
765
+ // / found.
766
+ SmallVector<Operand *, 8 > accumulatedEscapes;
761
767
762
768
// / A flag that we use to ensure that we only ever see 1 project_box on an
763
769
// / alloc_box.
@@ -786,16 +792,20 @@ static bool isNonMutatingLoad(SILInstruction *inst) {
786
792
return li->getOwnershipQualifier () != LoadOwnershipQualifier::Take;
787
793
}
788
794
789
- // / Given a partial_apply instruction and the argument index into its
790
- // / callee's argument list of a box argument (which is followed by an argument
791
- // / for the address of the box's contents), return true if the closure is known
792
- // / not to mutate the captured variable.
793
- static bool isNonMutatingCapture (SILArgument *boxArg) {
795
+ // / Given a partial_apply instruction and the argument index into its callee's
796
+ // / argument list of a box argument (which is followed by an argument for the
797
+ // / address of the box's contents), return true if this box has mutating
798
+ // / captures. Return false otherwise. All of the mutating captures that we find
799
+ // / are placed into \p accumulatedMutatingUses.
800
+ static bool getPartialApplyArgMutationsAndEscapes (
801
+ SILArgument *boxArg, SmallVectorImpl<Operand *> &accumulatedMutatingUses,
802
+ SmallVectorImpl<Operand *> &accumulatedEscapes) {
794
803
SmallVector<ProjectBoxInst *, 2 > projectBoxInsts;
795
804
796
805
// Conservatively do not allow any use of the box argument other than a
797
806
// strong_release or projection, since this is the pattern expected from
798
807
// SILGen.
808
+ SmallVector<Operand *, 32 > incrementalEscapes;
799
809
for (auto *use : boxArg->getUses ()) {
800
810
if (isa<StrongReleaseInst>(use->getUser ()) ||
801
811
isa<DestroyValueInst>(use->getUser ()))
@@ -806,7 +816,7 @@ static bool isNonMutatingCapture(SILArgument *boxArg) {
806
816
continue ;
807
817
}
808
818
809
- return false ;
819
+ incrementalEscapes. push_back (use) ;
810
820
}
811
821
812
822
// Only allow loads of projections, either directly or via
@@ -815,33 +825,44 @@ static bool isNonMutatingCapture(SILArgument *boxArg) {
815
825
// TODO: This seems overly limited. Why not projections of tuples and other
816
826
// stuff? Also, why not recursive struct elements? This should be a helper
817
827
// function that mirrors isNonEscapingUse.
818
- auto isAddrUseMutating = [](SILInstruction *addrInst) {
828
+ auto checkIfAddrUseMutating = [&](Operand *addrUse) -> bool {
829
+ unsigned initSize = incrementalEscapes.size ();
830
+ auto *addrInst = addrUse->getUser ();
819
831
if (auto *seai = dyn_cast<StructElementAddrInst>(addrInst)) {
820
- return all_of (seai->getUses (), [](Operand *op) -> bool {
821
- return isNonMutatingLoad (op->getUser ());
822
- });
832
+ for (auto *seaiUse : seai->getUses ()) {
833
+ if (!isNonMutatingLoad (seaiUse->getUser ())) {
834
+ incrementalEscapes.push_back (seaiUse);
835
+ }
836
+ }
837
+ return incrementalEscapes.size () != initSize;
823
838
}
824
839
825
- return isNonMutatingLoad (addrInst) || isa<DebugValueAddrInst>(addrInst) ||
826
- isa<MarkFunctionEscapeInst>(addrInst) ||
827
- isa<EndAccessInst>(addrInst);
840
+ if (isNonMutatingLoad (addrInst) || isa<DebugValueAddrInst>(addrInst) ||
841
+ isa<MarkFunctionEscapeInst>(addrInst) || isa<EndAccessInst>(addrInst)) {
842
+ return false ;
843
+ }
844
+
845
+ incrementalEscapes.push_back (addrUse);
846
+ return true ;
828
847
};
829
848
830
849
for (auto *pbi : projectBoxInsts) {
831
850
for (auto *use : pbi->getUses ()) {
832
851
if (auto *bai = dyn_cast<BeginAccessInst>(use->getUser ())) {
833
852
for (auto *accessUseOper : bai->getUses ()) {
834
- if (!isAddrUseMutating (accessUseOper->getUser ()))
835
- return false ;
853
+ checkIfAddrUseMutating (accessUseOper);
836
854
}
837
855
continue ;
838
856
}
839
857
840
- if (!isAddrUseMutating (use->getUser ()))
841
- return false ;
858
+ checkIfAddrUseMutating (use);
842
859
}
843
860
}
844
861
862
+ if (incrementalEscapes.empty ())
863
+ return false ;
864
+ while (!incrementalEscapes.empty ())
865
+ accumulatedEscapes.push_back (incrementalEscapes.pop_back_val ());
845
866
return true ;
846
867
}
847
868
@@ -852,11 +873,14 @@ bool isPartialApplyNonEscapingUser(Operand *currentOp, PartialApplyInst *pai,
852
873
unsigned opNo = currentOp->getOperandNumber ();
853
874
assert (opNo != 0 && " Alloc box used as callee of partial apply?" );
854
875
855
- // If we've already seen this partial apply, then it means the same alloc
856
- // box is being captured twice by the same closure, which is odd and
857
- // unexpected: bail instead of trying to handle this case.
876
+ // If we've already seen this partial apply, then it means the same alloc box
877
+ // is being captured twice by the same closure, which is odd and unexpected:
878
+ // bail instead of trying to handle this case.
858
879
if (state.globalIndexMap .count (pai)) {
880
+ // TODO: Is it correct to treat this like an escape? We are just currently
881
+ // flagging all failures as warnings.
859
882
LLVM_DEBUG (llvm::dbgs () << " FAIL! Already seen.\n " );
883
+ state.accumulatedEscapes .push_back (currentOp);
860
884
return false ;
861
885
}
862
886
@@ -877,6 +901,7 @@ bool isPartialApplyNonEscapingUser(Operand *currentOp, PartialApplyInst *pai,
877
901
if (!fn || !fn->isDefinition () || fn->isDynamicallyReplaceable ()) {
878
902
LLVM_DEBUG (llvm::dbgs () << " FAIL! Not a direct function definition "
879
903
" reference.\n " );
904
+ state.accumulatedEscapes .push_back (currentOp);
880
905
return false ;
881
906
}
882
907
@@ -893,14 +918,17 @@ bool isPartialApplyNonEscapingUser(Operand *currentOp, PartialApplyInst *pai,
893
918
.isAddressOnly (*f)) {
894
919
LLVM_DEBUG (llvm::dbgs () << " FAIL! Box is an address only "
895
920
" argument!\n " );
921
+ state.accumulatedEscapes .push_back (currentOp);
896
922
return false ;
897
923
}
898
924
899
925
// Verify that this closure is known not to mutate the captured value; if
900
926
// it does, then conservatively refuse to promote any captures of this
901
927
// value.
902
- if (!isNonMutatingCapture (boxArg)) {
903
- LLVM_DEBUG (llvm::dbgs () << " FAIL: Have a mutating capture!\n " );
928
+ if (getPartialApplyArgMutationsAndEscapes (boxArg, state.accumulatedMutations ,
929
+ state.accumulatedEscapes )) {
930
+ LLVM_DEBUG (llvm::dbgs () << " FAIL: Have a mutation or escape of a "
931
+ " partial apply arg?!\n " );
904
932
return false ;
905
933
}
906
934
@@ -920,15 +948,17 @@ namespace {
920
948
921
949
class NonEscapingUserVisitor
922
950
: public SILInstructionVisitor<NonEscapingUserVisitor, bool > {
923
- llvm::SmallVector<Operand *, 32 > worklist;
924
- llvm::SmallVectorImpl<SILInstruction *> &foundMutations;
951
+ SmallVector<Operand *, 32 > worklist;
952
+ SmallVectorImpl<Operand *> &accumulatedMutations;
953
+ SmallVectorImpl<Operand *> &accumulatedEscapes;
925
954
NullablePtr<Operand> currentOp;
926
955
927
956
public:
928
- NonEscapingUserVisitor (
929
- Operand *initialOperand,
930
- llvm::SmallVectorImpl<SILInstruction *> &foundMutations)
931
- : worklist(), foundMutations(foundMutations), currentOp() {
957
+ NonEscapingUserVisitor (Operand *initialOperand,
958
+ SmallVectorImpl<Operand *> &accumulatedMutations,
959
+ SmallVectorImpl<Operand *> &accumulatedEscapes)
960
+ : worklist(), accumulatedMutations(accumulatedMutations),
961
+ accumulatedEscapes (accumulatedEscapes), currentOp() {
932
962
worklist.push_back (initialOperand);
933
963
}
934
964
@@ -937,6 +967,15 @@ class NonEscapingUserVisitor
937
967
NonEscapingUserVisitor (NonEscapingUserVisitor &&) = delete;
938
968
NonEscapingUserVisitor &operator =(NonEscapingUserVisitor &&) = delete ;
939
969
970
+ private:
971
+ void markCurrentOpAsMutation () {
972
+ accumulatedMutations.push_back (currentOp.get ());
973
+ }
974
+ void markCurrentOpAsEscape () {
975
+ accumulatedEscapes.push_back (currentOp.get ());
976
+ }
977
+
978
+ public:
940
979
bool compute () {
941
980
while (!worklist.empty ()) {
942
981
currentOp = worklist.pop_back_val ();
@@ -964,6 +1003,7 @@ class NonEscapingUserVisitor
964
1003
bool visitSILInstruction (SILInstruction *inst) {
965
1004
LLVM_DEBUG (llvm::dbgs ()
966
1005
<< " FAIL! Have unknown escaping user: " << *inst);
1006
+ markCurrentOpAsEscape ();
967
1007
return false ;
968
1008
}
969
1009
@@ -979,7 +1019,7 @@ class NonEscapingUserVisitor
979
1019
#undef ALWAYS_NON_ESCAPING_INST
980
1020
981
1021
bool visitDeallocBoxInst (DeallocBoxInst *dbi) {
982
- foundMutations. push_back (dbi );
1022
+ markCurrentOpAsMutation ( );
983
1023
return true ;
984
1024
}
985
1025
@@ -992,9 +1032,10 @@ class NonEscapingUserVisitor
992
1032
if (!convention.isIndirectConvention ()) {
993
1033
LLVM_DEBUG (llvm::dbgs ()
994
1034
<< " FAIL! Found non indirect apply user: " << *ai);
1035
+ markCurrentOpAsEscape ();
995
1036
return false ;
996
1037
}
997
- foundMutations. push_back (ai );
1038
+ markCurrentOpAsMutation ( );
998
1039
return true ;
999
1040
}
1000
1041
@@ -1017,7 +1058,7 @@ class NonEscapingUserVisitor
1017
1058
#define RECURSIVE_INST_VISITOR (MUTATING, INST ) \
1018
1059
bool visit##INST##Inst(INST##Inst *i) { \
1019
1060
if (bool (detail::MUTATING)) { \
1020
- foundMutations. push_back (i); \
1061
+ markCurrentOpAsMutation (); \
1021
1062
} \
1022
1063
addUsesToWorklist (i); \
1023
1064
return true ; \
@@ -1044,25 +1085,27 @@ class NonEscapingUserVisitor
1044
1085
1045
1086
bool visitCopyAddrInst (CopyAddrInst *cai) {
1046
1087
if (currentOp.get ()->getOperandNumber () == 1 || cai->isTakeOfSrc ())
1047
- foundMutations. push_back (cai );
1088
+ markCurrentOpAsMutation ( );
1048
1089
return true ;
1049
1090
}
1050
1091
1051
1092
bool visitStoreInst (StoreInst *si) {
1052
1093
if (currentOp.get ()->getOperandNumber () != 1 ) {
1053
1094
LLVM_DEBUG (llvm::dbgs () << " FAIL! Found store of pointer: " << *si);
1095
+ markCurrentOpAsEscape ();
1054
1096
return false ;
1055
1097
}
1056
- foundMutations. push_back (si );
1098
+ markCurrentOpAsMutation ( );
1057
1099
return true ;
1058
1100
}
1059
1101
1060
1102
bool visitAssignInst (AssignInst *ai) {
1061
1103
if (currentOp.get ()->getOperandNumber () != 1 ) {
1062
1104
LLVM_DEBUG (llvm::dbgs () << " FAIL! Found store of pointer: " << *ai);
1105
+ markCurrentOpAsEscape ();
1063
1106
return false ;
1064
1107
}
1065
- foundMutations. push_back (ai );
1108
+ markCurrentOpAsMutation ( );
1066
1109
return true ;
1067
1110
}
1068
1111
};
@@ -1075,7 +1118,9 @@ class NonEscapingUserVisitor
1075
1118
// / the Mutations vector.
1076
1119
static bool isNonEscapingUse (Operand *initialOp,
1077
1120
EscapeMutationScanningState &state) {
1078
- return NonEscapingUserVisitor (initialOp, state.foundMutations ).compute ();
1121
+ return NonEscapingUserVisitor (initialOp, state.accumulatedMutations ,
1122
+ state.accumulatedEscapes )
1123
+ .compute ();
1079
1124
}
1080
1125
1081
1126
static bool isProjectBoxNonEscapingUse (ProjectBoxInst *pbi,
@@ -1097,12 +1142,12 @@ static bool isProjectBoxNonEscapingUse(ProjectBoxInst *pbi,
1097
1142
// Top Level AllocBox Escape/Mutation Analysis
1098
1143
// ===----------------------------------------------------------------------===//
1099
1144
1100
- static bool scanUsesForEscapesAndMutations (Operand *op,
1101
- EscapeMutationScanningState &state) {
1145
+ static bool findEscapeOrMutationUses (Operand *op,
1146
+ EscapeMutationScanningState &state) {
1102
1147
SILInstruction *user = op->getUser ();
1103
1148
1104
1149
if (auto *pai = dyn_cast<PartialApplyInst>(user)) {
1105
- return isPartialApplyNonEscapingUser (op, pai, state);
1150
+ return ! isPartialApplyNonEscapingUser (op, pai, state);
1106
1151
}
1107
1152
1108
1153
// A mark_dependence user on a partial_apply is safe.
@@ -1112,7 +1157,11 @@ static bool scanUsesForEscapesAndMutations(Operand *op,
1112
1157
while ((mdi = dyn_cast<MarkDependenceInst>(parent))) {
1113
1158
parent = mdi->getValue ();
1114
1159
}
1115
- return isa<PartialApplyInst>(parent);
1160
+ if (isa<PartialApplyInst>(parent))
1161
+ return false ;
1162
+ state.accumulatedEscapes .push_back (
1163
+ &mdi->getOperandRef (MarkDependenceInst::Value));
1164
+ return true ;
1116
1165
}
1117
1166
}
1118
1167
@@ -1121,9 +1170,9 @@ static bool scanUsesForEscapesAndMutations(Operand *op,
1121
1170
// can be seen since there is no code for reasoning about multiple
1122
1171
// boxes. Just put in the restriction so we are consistent.
1123
1172
if (state.sawProjectBoxInst )
1124
- return false ;
1173
+ return true ;
1125
1174
state.sawProjectBoxInst = true ;
1126
- return isProjectBoxNonEscapingUse (pbi, state);
1175
+ return ! isProjectBoxNonEscapingUse (pbi, state);
1127
1176
}
1128
1177
1129
1178
// Given a top level copy value use or mark_uninitialized, check all of its
@@ -1134,10 +1183,11 @@ static bool scanUsesForEscapesAndMutations(Operand *op,
1134
1183
// derived from a projection like instruction). In fact such a thing may not
1135
1184
// even make any sense!
1136
1185
if (isa<CopyValueInst>(user) || isa<MarkUninitializedInst>(user)) {
1137
- return all_of (cast<SingleValueInstruction>(user)->getUses (),
1138
- [&state](Operand *userOp) -> bool {
1139
- return scanUsesForEscapesAndMutations (userOp, state);
1140
- });
1186
+ bool foundSomeMutations = false ;
1187
+ for (auto *use : cast<SingleValueInstruction>(user)->getUses ()) {
1188
+ foundSomeMutations |= findEscapeOrMutationUses (use, state);
1189
+ }
1190
+ return foundSomeMutations;
1141
1191
}
1142
1192
1143
1193
// Verify that this use does not otherwise allow the alloc_box to
@@ -1153,14 +1203,20 @@ static bool
1153
1203
examineAllocBoxInst (AllocBoxInst *abi, ReachabilityInfo &ri,
1154
1204
llvm::DenseMap<PartialApplyInst *, unsigned > &im) {
1155
1205
LLVM_DEBUG (llvm::dbgs () << " Visiting alloc box: " << *abi);
1156
- EscapeMutationScanningState state{{}, false , im};
1206
+ EscapeMutationScanningState state{{}, {}, false , im};
1157
1207
1158
- // Scan the box for interesting uses.
1159
- if (any_of (abi->getUses (), [&state](Operand *op) {
1160
- return !scanUsesForEscapesAndMutations (op, state);
1161
- })) {
1208
+ // Scan the box for escaping or mutating uses.
1209
+ for (auto *use : abi->getUses ()) {
1210
+ findEscapeOrMutationUses (use, state);
1211
+ }
1212
+
1213
+ if (!state.accumulatedEscapes .empty ()) {
1162
1214
LLVM_DEBUG (llvm::dbgs ()
1163
- << " Found an escaping use! Can not optimize this alloc box?!\n " );
1215
+ << " Found escaping uses! Can not optimize this alloc box?!\n " );
1216
+ while (!state.accumulatedEscapes .empty ()) {
1217
+ auto *escapingUse = state.accumulatedEscapes .pop_back_val ();
1218
+ LLVM_DEBUG (llvm::dbgs () << " Escaping use: " << *escapingUse->getUser ());
1219
+ }
1164
1220
return false ;
1165
1221
}
1166
1222
@@ -1183,17 +1239,18 @@ examineAllocBoxInst(AllocBoxInst *abi, ReachabilityInfo &ri,
1183
1239
LLVM_DEBUG (llvm::dbgs ()
1184
1240
<< " Checking for any mutations that invalidate captures...\n " );
1185
1241
// Loop over all mutations to possibly invalidate captures.
1186
- for (auto *inst : state.foundMutations ) {
1242
+ for (auto *use : state.accumulatedMutations ) {
1187
1243
auto iter = im.begin ();
1188
1244
while (iter != im.end ()) {
1245
+ auto *user = use->getUser ();
1189
1246
auto *pai = iter->first ;
1190
1247
// The mutation invalidates a capture if it occurs in a block reachable
1191
1248
// from the block the partial_apply is in, or if it is in the same
1192
1249
// block is after the partial_apply.
1193
- if (ri.isReachable (pai->getParent (), inst ->getParent ()) ||
1194
- (pai->getParent () == inst ->getParent () && isAfter (pai, inst ))) {
1250
+ if (ri.isReachable (pai->getParent (), user ->getParent ()) ||
1251
+ (pai->getParent () == user ->getParent () && isAfter (pai, user ))) {
1195
1252
LLVM_DEBUG (llvm::dbgs () << " Invalidating: " << *pai);
1196
- LLVM_DEBUG (llvm::dbgs () << " Because of user: " << *inst );
1253
+ LLVM_DEBUG (llvm::dbgs () << " Because of user: " << *user );
1197
1254
auto prev = iter++;
1198
1255
im.erase (prev);
1199
1256
continue ;
0 commit comments