Skip to content

Commit 7ea69d1

Browse files
committed
Fix closure specializer for partial_apply [stack]
1 parent 0d9aced commit 7ea69d1

File tree

2 files changed

+164
-26
lines changed

2 files changed

+164
-26
lines changed

lib/SILOptimizer/IPO/ClosureSpecializer.cpp

Lines changed: 144 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,9 @@
7272
#include "swift/SILOptimizer/Utils/SILInliner.h"
7373
#include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h"
7474
#include "swift/SILOptimizer/Utils/SpecializationMangler.h"
75+
#include "swift/SILOptimizer/Utils/StackNesting.h"
7576
#include "llvm/ADT/SmallString.h"
77+
#include "llvm/ADT/SmallSet.h"
7678
#include "llvm/ADT/Statistic.h"
7779
#include "llvm/Support/CommandLine.h"
7880
#include "llvm/Support/Debug.h"
@@ -130,7 +132,8 @@ class ClosureSpecCloner : public SILClonerWithScopes<ClosureSpecCloner> {
130132
SILValue
131133
cloneCalleeConversion(SILValue calleeValue, SILValue NewClosure,
132134
SILBuilder &Builder,
133-
SmallVectorImpl<PartialApplyInst *> &NeedsRelease);
135+
SmallVectorImpl<PartialApplyInst *> &NeedsRelease,
136+
llvm::DenseMap<SILValue, SILValue> &CapturedMap);
134137

135138
SILFunction *getCloned() { return &getBuilder().getFunction(); }
136139
static SILFunction *cloneFunction(SILOptFunctionBuilder &FunctionBuilder,
@@ -192,7 +195,18 @@ class CallSiteDescriptor {
192195
}
193196

194197
bool closureHasRefSemanticContext() const {
195-
return isa<PartialApplyInst>(getClosure());
198+
return isa<PartialApplyInst>(getClosure()) &&
199+
!cast<PartialApplyInst>(getClosure())->isOnStack();
200+
}
201+
202+
/// Emit the cleanup for \p newClosure if it is an on-stack partial_apply.
///
/// When \p newClosure is a PartialApplyInst whose isOnStack() is true, this
/// calls insertDestroyOfCapturedArguments for it and then creates a
/// dealloc_stack of it, both through builder \p B, and returns true. The
/// dealloc_stack uses the location of the original closure (getClosure()),
/// not that of \p newClosure.
///
/// For any other instruction (not a partial_apply, or a partial_apply that
/// is not on-stack) nothing is emitted and false is returned.
bool destroyIfPartialApplyStack(SILBuilder &B,
203+
SingleValueInstruction *newClosure) const {
204+
auto *PA = dyn_cast<PartialApplyInst>(newClosure);
205+
if (!PA || !PA->isOnStack())
206+
return false;
207+
insertDestroyOfCapturedArguments(PA, B);
208+
B.createDeallocStack(getClosure()->getLoc(), PA);
209+
return true;
196210
}
197211

198212
unsigned getClosureIndex() const { return ClosureIndex; }
@@ -207,12 +221,13 @@ class CallSiteDescriptor {
207221
SingleValueInstruction *
208222
createNewClosure(SILBuilder &B, SILValue V,
209223
llvm::SmallVectorImpl<SILValue> &Args) const {
210-
if (isa<PartialApplyInst>(getClosure()))
224+
if (auto *PA = dyn_cast<PartialApplyInst>(getClosure()))
211225
return B.createPartialApply(getClosure()->getLoc(), V, {}, Args,
212226
getClosure()
213227
->getType()
214228
.getAs<SILFunctionType>()
215-
->getCalleeConvention());
229+
->getCalleeConvention(),
230+
PA->isOnStack());
216231

217232
assert(isa<ThinToThickFunctionInst>(getClosure()) &&
218233
"We only support partial_apply and thin_to_thick_function");
@@ -256,6 +271,13 @@ class CallSiteDescriptor {
256271
return getClosureParameterInfo().isConsumed();
257272
}
258273

274+
/// Returns true if the closure of this call site is a partial_apply whose
/// isOnStack() is true. Returns false when the closure is not a
/// PartialApplyInst at all, or is a partial_apply that is not on-stack.
bool isClosureOnStack() const {
275+
auto *PA = dyn_cast<PartialApplyInst>(getClosure());
276+
if (!PA)
277+
return false;
278+
return PA->isOnStack();
279+
}
280+
259281
bool isTrivialNoEscapeParameter() const {
260282
auto ClosureParmFnTy =
261283
getClosureParameterInfo().getType()->getAs<SILFunctionType>();
@@ -675,16 +697,29 @@ ClosureSpecCloner::initCloned(SILOptFunctionBuilder &FunctionBuilder,
675697
// Clone a chain of ConvertFunctionInsts.
676698
SILValue ClosureSpecCloner::cloneCalleeConversion(
677699
SILValue calleeValue, SILValue NewClosure, SILBuilder &Builder,
678-
SmallVectorImpl<PartialApplyInst *> &NeedsRelease) {
700+
SmallVectorImpl<PartialApplyInst *> &NeedsRelease,
701+
llvm::DenseMap<SILValue, SILValue> &CapturedMap) {
702+
703+
// There might be a mark dependence on a previous closure value. Therefore, we
704+
// add all closure values to the map.
705+
auto addToOldToNewClosureMap = [&](SILValue origValue,
706+
SILValue newValue) -> SILValue {
707+
assert(!CapturedMap.count(origValue));
708+
CapturedMap[origValue] = newValue;
709+
return newValue;
710+
};
711+
679712
if (calleeValue == CallSiteDesc.getClosure())
680-
return NewClosure;
713+
return addToOldToNewClosureMap(calleeValue, NewClosure);
681714

682715
if (auto *CFI = dyn_cast<ConvertFunctionInst>(calleeValue)) {
716+
SILValue origCalleeValue = calleeValue;
683717
calleeValue = cloneCalleeConversion(CFI->getOperand(), NewClosure, Builder,
684-
NeedsRelease);
685-
return Builder.createConvertFunction(CallSiteDesc.getLoc(), calleeValue,
686-
CFI->getType(),
687-
CFI->withoutActuallyEscaping());
718+
NeedsRelease, CapturedMap);
719+
return addToOldToNewClosureMap(
720+
origCalleeValue, Builder.createConvertFunction(
721+
CallSiteDesc.getLoc(), calleeValue, CFI->getType(),
722+
CFI->withoutActuallyEscaping()));
688723
}
689724

690725
if (auto *PAI = dyn_cast<PartialApplyInst>(calleeValue)) {
@@ -693,27 +728,52 @@ SILValue ClosureSpecCloner::cloneCalleeConversion(
693728
->getType()
694729
.getAs<SILFunctionType>()
695730
->isTrivialNoEscape());
731+
SILValue origCalleeValue = calleeValue;
696732
calleeValue = cloneCalleeConversion(PAI->getArgument(0), NewClosure,
697-
Builder, NeedsRelease);
733+
Builder, NeedsRelease, CapturedMap);
698734
auto FunRef = Builder.createFunctionRef(CallSiteDesc.getLoc(),
699735
PAI->getReferencedFunction());
700736
auto NewPA = Builder.createPartialApply(
701737
CallSiteDesc.getLoc(), FunRef, {}, {calleeValue},
702-
PAI->getType().getAs<SILFunctionType>()->getCalleeConvention());
738+
PAI->getType().getAs<SILFunctionType>()->getCalleeConvention(),
739+
PAI->isOnStack());
740+
// If the partial_apply is on stack we will emit a dealloc_stack in the
741+
// epilog.
703742
NeedsRelease.push_back(NewPA);
704-
return NewPA;
743+
return addToOldToNewClosureMap(origCalleeValue, NewPA);
744+
}
745+
746+
if (auto *MD = dyn_cast<MarkDependenceInst>(calleeValue)) {
747+
SILValue origCalleeValue = calleeValue;
748+
calleeValue = cloneCalleeConversion(MD->getValue(), NewClosure, Builder,
749+
NeedsRelease, CapturedMap);
750+
if (!CapturedMap.count(MD->getBase())) {
751+
CallSiteDesc.getClosure()->dump();
752+
MD->dump();
753+
MD->getFunction()->dump();
754+
}
755+
assert(CapturedMap.count(MD->getBase()));
756+
return addToOldToNewClosureMap(
757+
origCalleeValue,
758+
Builder.createMarkDependence(CallSiteDesc.getLoc(), calleeValue,
759+
CapturedMap[MD->getBase()]));
705760
}
706761

762+
707763
auto *Cvt = cast<ConvertEscapeToNoEscapeInst>(calleeValue);
764+
SILValue origCalleeValue = calleeValue;
708765
calleeValue = cloneCalleeConversion(Cvt->getOperand(), NewClosure, Builder,
709-
NeedsRelease);
710-
return Builder.createConvertEscapeToNoEscape(
711-
CallSiteDesc.getLoc(), calleeValue, Cvt->getType(), true);
766+
NeedsRelease, CapturedMap);
767+
return addToOldToNewClosureMap(
768+
origCalleeValue,
769+
Builder.createConvertEscapeToNoEscape(CallSiteDesc.getLoc(), calleeValue,
770+
Cvt->getType(), true));
712771
}
713772

714773
/// Populate the body of the cloned closure, modifying instructions as
715-
/// necessary. This is where we create the actual specialized BB Arguments.
774+
/// necessary. This is where we create the actual specialized BB Arguments.
716775
void ClosureSpecCloner::populateCloned() {
776+
bool needToUpdateStackNesting = false;
717777
SILFunction *Cloned = getCloned();
718778
SILFunction *ClosureUser = CallSiteDesc.getApplyCallee();
719779

@@ -752,10 +812,15 @@ void ClosureSpecCloner::populateCloned() {
752812
unsigned NumTotalParams = ClosedOverFunConv.getNumParameters();
753813
unsigned NumNotCaptured = NumTotalParams - CallSiteDesc.getNumArguments();
754814
llvm::SmallVector<SILValue, 4> NewPAIArgs;
815+
llvm::DenseMap<SILValue, SILValue> CapturedMap;
816+
unsigned idx = 0;
755817
for (auto &PInfo : ClosedOverFunConv.getParameters().slice(NumNotCaptured)) {
756818
auto paramTy = ClosedOverFunConv.getSILType(PInfo);
757819
SILValue MappedValue = ClonedEntryBB->createFunctionArgument(paramTy);
758820
NewPAIArgs.push_back(MappedValue);
821+
auto CapturedVal =
822+
cast<PartialApplyInst>(CallSiteDesc.getClosure())->getArgument(idx++);
823+
CapturedMap[CapturedVal] = MappedValue;
759824
}
760825

761826
SILBuilder &Builder = getBuilder();
@@ -770,8 +835,9 @@ void ClosureSpecCloner::populateCloned() {
770835
// Clone a chain of ConvertFunctionInsts. This can create further
771836
// reabstraction partial_apply instructions.
772837
SmallVector<PartialApplyInst*, 4> NeedsRelease;
773-
SILValue ConvertedCallee = cloneCalleeConversion(
774-
CallSiteDesc.getClosureCallerArg(), NewClosure, Builder, NeedsRelease);
838+
SILValue ConvertedCallee =
839+
cloneCalleeConversion(CallSiteDesc.getClosureCallerArg(), NewClosure,
840+
Builder, NeedsRelease, CapturedMap);
775841

776842
// Make sure that we actually emit the releases for reabstraction thunks. We
777843
// have guaranteed earlier that we only allow reabstraction thunks if the
@@ -790,7 +856,8 @@ void ClosureSpecCloner::populateCloned() {
790856
bool ClosureHasRefSemantics = CallSiteDesc.closureHasRefSemanticContext();
791857
if ((CallSiteDesc.isClosureGuaranteed() ||
792858
CallSiteDesc.isTrivialNoEscapeParameter()) &&
793-
(ClosureHasRefSemantics || !NeedsRelease.empty())) {
859+
(ClosureHasRefSemantics || !NeedsRelease.empty() ||
860+
CallSiteDesc.isClosureOnStack())) {
794861
for (SILBasicBlock *BB : CallSiteDesc.getNonFailureExitBBs()) {
795862
SILBasicBlock *OpBB = getOpBasicBlock(BB);
796863

@@ -804,9 +871,17 @@ void ClosureSpecCloner::populateCloned() {
804871
if (ClosureHasRefSemantics)
805872
Builder.createReleaseValue(Loc, SILValue(NewClosure),
806873
Builder.getDefaultAtomicity());
807-
for (auto PAI : NeedsRelease)
808-
Builder.createReleaseValue(Loc, SILValue(PAI),
809-
Builder.getDefaultAtomicity());
874+
else
875+
needToUpdateStackNesting |=
876+
CallSiteDesc.destroyIfPartialApplyStack(Builder, NewClosure);
877+
for (auto PAI : NeedsRelease) {
878+
if (PAI->isOnStack())
879+
needToUpdateStackNesting |=
880+
CallSiteDesc.destroyIfPartialApplyStack(Builder, PAI);
881+
else
882+
Builder.createReleaseValue(Loc, SILValue(PAI),
883+
Builder.getDefaultAtomicity());
884+
}
810885
continue;
811886
}
812887

@@ -825,11 +900,22 @@ void ClosureSpecCloner::populateCloned() {
825900
if (ClosureHasRefSemantics)
826901
Builder.createReleaseValue(Loc, SILValue(NewClosure),
827902
Builder.getDefaultAtomicity());
828-
for (auto PAI : NeedsRelease)
829-
Builder.createReleaseValue(Loc, SILValue(PAI),
830-
Builder.getDefaultAtomicity());
903+
else
904+
needToUpdateStackNesting |=
905+
CallSiteDesc.destroyIfPartialApplyStack(Builder, NewClosure);
906+
for (auto PAI : NeedsRelease) {
907+
if (PAI->isOnStack())
908+
needToUpdateStackNesting |=
909+
CallSiteDesc.destroyIfPartialApplyStack(Builder, PAI);
910+
else
911+
Builder.createReleaseValue(Loc, SILValue(PAI),
912+
Builder.getDefaultAtomicity());
913+
}
831914
}
832915
}
916+
if (needToUpdateStackNesting) {
917+
StackNesting().correctStackNesting(Cloned);
918+
}
833919
}
834920

835921
//===----------------------------------------------------------------------===//
@@ -913,6 +999,10 @@ static void markReabstractionPartialApplyAsUsed(
913999
return markReabstractionPartialApplyAsUsed(FirstClosure, Cvt->getOperand(),
9141000
UsedReabstractionClosure);
9151001
}
1002+
if (auto MD = dyn_cast<MarkDependenceInst>(Current)) {
1003+
return markReabstractionPartialApplyAsUsed(FirstClosure, MD->getValue(),
1004+
UsedReabstractionClosure);
1005+
}
9161006
llvm_unreachable("Unexpect instruction");
9171007
}
9181008

@@ -994,6 +1084,14 @@ bool SILClosureSpecializerTransform::gatherCallSites(
9941084
// Live range end points.
9951085
SmallVector<SILInstruction *, 8> UsePoints;
9961086

1087+
// Set of possible arguments for mark_dependence. The base of a
1088+
// mark_dependence we copy must be available in the specialized function.
1089+
llvm::SmallSet<SILValue, 16> PossibleMarkDependenceBases;
1090+
if (auto *PA = dyn_cast<PartialApplyInst>(ClosureInst)) {
1091+
for (auto Opd : PA->getArguments())
1092+
PossibleMarkDependenceBases.insert(Opd);
1093+
}
1094+
9971095
bool HaveUsedReabstraction = false;
9981096
// Uses may grow in this loop.
9991097
for (size_t UseIndex = 0; UseIndex < Uses.size(); ++UseIndex) {
@@ -1004,10 +1102,12 @@ bool SILClosureSpecializerTransform::gatherCallSites(
10041102
if (auto *CFI = dyn_cast<ConvertFunctionInst>(Use->getUser())) {
10051103
// Push Uses in reverse order so they are visited in forward order.
10061104
Uses.append(CFI->getUses().begin(), CFI->getUses().end());
1105+
PossibleMarkDependenceBases.insert(CFI);
10071106
continue;
10081107
}
10091108
if (auto *Cvt = dyn_cast<ConvertEscapeToNoEscapeInst>(Use->getUser())) {
10101109
Uses.append(Cvt->getUses().begin(), Cvt->getUses().end());
1110+
PossibleMarkDependenceBases.insert(Cvt);
10111111
continue;
10121112
}
10131113

@@ -1025,11 +1125,29 @@ bool SILClosureSpecializerTransform::gatherCallSites(
10251125
.getAs<SILFunctionType>()
10261126
->isTrivialNoEscape()) {
10271127
Uses.append(PA->getUses().begin(), PA->getUses().end());
1128+
PossibleMarkDependenceBases.insert(PA);
10281129
HaveUsedReabstraction = true;
10291130
}
10301131
continue;
10311132
}
10321133

1134+
// Look through mark_dependence on partial_apply [stack].
1135+
if (auto *MD = dyn_cast<MarkDependenceInst>(Use->getUser())) {
1136+
// We can't copy a closure if the mark_dependence base is not
1137+
// available in the specialized function.
1138+
if (!PossibleMarkDependenceBases.count(MD->getBase()))
1139+
continue;
1140+
if (MD->getValue() == Use->get() &&
1141+
MD->getValue()->getType().is<SILFunctionType>() &&
1142+
MD->getValue()
1143+
->getType()
1144+
.castTo<SILFunctionType>()
1145+
->isTrivialNoEscape()) {
1146+
Uses.append(MD->getUses().begin(), MD->getUses().end());
1147+
continue;
1148+
}
1149+
}
1150+
10331151
// If this use is not a full apply site that we can process or an apply
10341152
// inst with substitutions, there is nothing interesting for us to do,
10351153
// so continue...

test/SILOptimizer/closure_specialize.sil

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -808,3 +808,23 @@ bb3:
808808
release_value %2 : $@callee_guaranteed () -> Int
809809
return %value : $Int
810810
}
811+
// CHECK-LABEL: sil @reabstractionTest_on_stack
812+
// CHECK: bb0([[A:%.*]] : $Int):
813+
// CHECK: [[R:%.*]] = alloc_stack $Int
814+
// CHECK: [[F:%.*]] = function_ref @$s25testClosureThunkNoEscape20aB14ConvertHelper2SiTf1nc_n
815+
// CHECK: apply [[F]]([[R]], [[A]])
816+
// Test input: %50 is an on-stack partial_apply of @testClosureConvertHelper2
// capturing %0, and %53 is an on-stack partial_apply of @reabstractionThunk
// capturing %50. %53 is passed as the closure argument to
// @testClosureThunkNoEscape2. The two closures and the alloc_stack are
// deallocated in reverse order of their creation.
sil @reabstractionTest_on_stack : $(Int) -> () {
817+
bb0(%0 : $Int):
818+
%48 = alloc_stack $Int
819+
%49 = function_ref @testClosureConvertHelper2 : $@convention(thin) (Int) -> Int
820+
%50 = partial_apply [callee_guaranteed] [on_stack] %49(%0) : $@convention(thin) (Int) -> Int
821+
%52 = function_ref @reabstractionThunk : $@convention(thin) (@noescape @callee_guaranteed () -> Int) -> @out Int
822+
%53 = partial_apply [callee_guaranteed] [on_stack] %52(%50) : $@convention(thin) (@noescape @callee_guaranteed () -> Int) -> @out Int
823+
%55 = function_ref @testClosureThunkNoEscape2 : $@convention(thin) (@noescape @callee_guaranteed () -> @out Int) -> @out Int
824+
apply %55(%48, %53) : $@convention(thin) (@noescape @callee_guaranteed () -> @out Int) -> @out Int
825+
dealloc_stack %53 : $@noescape @callee_guaranteed () -> @out Int
826+
dealloc_stack %50 : $@noescape @callee_guaranteed () -> Int
827+
dealloc_stack %48 : $*Int
828+
%empty = tuple ()
829+
return %empty : $()
830+
}

0 commit comments

Comments
 (0)