Skip to content

Commit 34d79e3

Browse files
committed
CapturePropagation: specialize closures which capture a constant keypath
This optimizes keypath-closures, like ``` a.map { \.x } ``` It results in a significant performance improvement for such code patterns. rdar://87968067
1 parent 13d2b1f commit 34d79e3

File tree

7 files changed

+362
-29
lines changed

7 files changed

+362
-29
lines changed

include/swift/SIL/SILInstruction.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ class StringLiteralExpr;
7878
class ValueDecl;
7979
class VarDecl;
8080
class FunctionRefBaseInst;
81+
class SILPrintContext;
8182

8283
template <typename ImplClass> class SILClonerWithScopes;
8384

@@ -3489,7 +3490,9 @@ class KeyPathPatternComponent {
34893490

34903491
void incrementRefCounts() const;
34913492
void decrementRefCounts() const;
3492-
3493+
3494+
void print(SILPrintContext &ctxt) const;
3495+
34933496
void Profile(llvm::FoldingSetNodeID &ID);
34943497
};
34953498

include/swift/SILOptimizer/Utils/SpecializationMangler.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,9 @@ class FunctionSignatureSpecializationMangler : public SpecializationMangler {
9595
FunctionSignatureSpecializationMangler(SpecializationPass Pass,
9696
IsSerialized_t Serialized,
9797
SILFunction *F);
98-
void setArgumentConstantProp(unsigned OrigArgIdx, LiteralInst *LI);
98+
void setArgumentConstantProp(unsigned OrigArgIdx, SILInstruction *constInst);
99+
void appendStringAsIdentifier(StringRef str);
100+
99101
void setArgumentClosureProp(unsigned OrigArgIdx, PartialApplyInst *PAI);
100102
void setArgumentClosureProp(unsigned OrigArgIdx,
101103
ThinToThickFunctionInst *TTTFI);
@@ -112,7 +114,7 @@ class FunctionSignatureSpecializationMangler : public SpecializationMangler {
112114
std::string mangle();
113115

114116
private:
115-
void mangleConstantProp(LiteralInst *LI);
117+
void mangleConstantProp(SILInstruction *constInst);
116118
void mangleClosureProp(SILInstruction *Inst);
117119
void mangleArgument(ArgumentModifierIntBase ArgMod,
118120
NullablePtr<SILInstruction> Inst);

lib/SIL/IR/SILPrinter.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3855,6 +3855,11 @@ void SILSpecializeAttr::print(llvm::raw_ostream &OS) const {
38553855
}
38563856
}
38573857

3858+
void KeyPathPatternComponent::print(SILPrintContext &ctxt) const {
3859+
SILPrinter printer(ctxt);
3860+
printer.printKeyPathPatternComponent(*this);
3861+
}
3862+
38583863
//===----------------------------------------------------------------------===//
38593864
// SILPrintContext members
38603865
//===----------------------------------------------------------------------===//

lib/SILOptimizer/IPO/CapturePropagation.cpp

Lines changed: 93 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -50,25 +50,38 @@ class CapturePropagation : public SILFunctionTransform
5050
};
5151
} // end anonymous namespace
5252

53-
static LiteralInst *getConstant(SILValue V) {
53+
static SILInstruction *getConstant(SILValue V) {
5454
if (auto I = dyn_cast<ThinToThickFunctionInst>(V))
5555
return getConstant(I->getOperand());
5656
if (auto I = dyn_cast<ConvertFunctionInst>(V))
5757
return getConstant(I->getOperand());
58-
return dyn_cast<LiteralInst>(V);
59-
}
6058

61-
static bool isOptimizableConstant(SILValue V) {
62-
// We do not optimize string literals of length > 32 since we would need to
63-
// encode them into the symbol name for uniqueness.
64-
if (auto *SLI = dyn_cast<StringLiteralInst>(V))
65-
return SLI->getValue().size() <= 32;
66-
return true;
67-
}
59+
if (auto *SLI = dyn_cast<StringLiteralInst>(V)) {
60+
// We do not optimize string literals of length > 32 since we would need to
61+
// encode them into the symbol name for uniqueness.
62+
if (SLI->getValue().size() > 32)
63+
return nullptr;
64+
return SLI;
65+
}
66+
67+
if (auto *lit = dyn_cast<LiteralInst>(V))
68+
return lit;
6869

69-
static bool isConstant(SILValue V) {
70-
V = getConstant(V);
71-
return V && isOptimizableConstant(V);
70+
if (auto *kp = dyn_cast<KeyPathInst>(V)) {
71+
// We could support operands, if they are constants, to enable propagation
72+
// of subscript keypaths. This would require to add the operands in the
73+
// mangling scheme.
74+
// But currently it's not worth it because we do not optimize subscript
75+
// keypaths in SILCombine.
76+
if (kp->getNumOperands() != 0)
77+
return nullptr;
78+
if (!kp->hasPattern())
79+
return nullptr;
80+
if (kp->getSubstitutions().hasAnySubstitutableParams())
81+
return nullptr;
82+
return kp;
83+
}
84+
return nullptr;
7285
}
7386

7487
static std::string getClonedName(PartialApplyInst *PAI, IsSerialized_t Serialized,
@@ -194,11 +207,15 @@ void CapturePropagationCloner::cloneClosure(
194207
// Replace the rest of the old arguments with constants.
195208
getBuilder().setInsertionPoint(ClonedEntryBB);
196209
IsCloningConstant = true;
210+
llvm::SmallVector<KeyPathInst *, 8> toDestroy;
197211
for (SILValue PartialApplyArg : PartialApplyArgs) {
198-
assert(isConstant(PartialApplyArg) &&
212+
assert(getConstant(PartialApplyArg) &&
199213
"expected a constant arg to partial apply");
200214

201215
cloneConstValue(PartialApplyArg);
216+
if (auto *kp = dyn_cast<KeyPathInst>(getMappedValue(PartialApplyArg))) {
217+
toDestroy.push_back(kp);
218+
}
202219

203220
// The PartialApplyArg from the caller is now mapped to its cloned
204221
// instruction. Also map the original argument to the cloned instruction.
@@ -213,6 +230,22 @@ void CapturePropagationCloner::cloneClosure(
213230
// Visit original BBs in depth-first preorder, starting with the
214231
// entry block, cloning all instructions and terminators.
215232
cloneFunctionBody(OrigF, ClonedEntryBB, entryArgs);
233+
234+
// Destroy all the inserted keypaths at the function exits.
235+
for (KeyPathInst *kpToDestroy : toDestroy) {
236+
SILLocation loc = RegularLocation::getAutoGeneratedLocation();
237+
for (SILBasicBlock &clonedBB : CloneF) {
238+
TermInst *term = clonedBB.getTerminator();
239+
if (term->isFunctionExiting()) {
240+
SILBuilder builder(term);
241+
if (CloneF.hasOwnership()) {
242+
builder.createDestroyValue(loc, kpToDestroy);
243+
} else {
244+
builder.createStrongRelease(loc, kpToDestroy, builder.getDefaultAtomicity());
245+
}
246+
}
247+
}
248+
}
216249
}
217250

218251
CanSILFunctionType getPartialApplyInterfaceResultType(PartialApplyInst *PAI) {
@@ -305,6 +338,20 @@ void CapturePropagation::rewritePartialApply(PartialApplyInst *OrigPAI,
305338
LLVM_DEBUG(llvm::dbgs() << " Rewrote caller:\n" << *T2TF);
306339
}
307340

341+
static bool isKeyPathFunction(FullApplySite FAS, SILValue keyPath) {
342+
SILFunction *callee = FAS.getReferencedFunctionOrNull();
343+
if (!callee)
344+
return false;
345+
if (callee->getName() == "swift_setAtWritableKeyPath" ||
346+
callee->getName() == "swift_setAtReferenceWritableKeyPath") {
347+
return FAS.getArgument(1) == keyPath;
348+
}
349+
if (callee->getName() == "swift_getAtKeyPath") {
350+
return FAS.getArgument(2) == keyPath;
351+
}
352+
return false;
353+
}
354+
308355
/// For now, we conservative only specialize if doing so can eliminate dynamic
309356
/// dispatch.
310357
///
@@ -316,6 +363,8 @@ static bool isProfitable(SILFunction *Callee) {
316363
if (FullApplySite FAS = FullApplySite::isa(Operand->getUser())) {
317364
if (FAS.getCallee() == Operand->get())
318365
return true;
366+
if (isKeyPathFunction(FAS, Arg))
367+
return true;
319368
}
320369
}
321370
}
@@ -477,9 +526,35 @@ bool CapturePropagation::optimizePartialApply(PartialApplyInst *PAI) {
477526
}
478527

479528
// Second possibility: Are all partially applied arguments constant?
480-
for (auto Arg : PAI->getArguments()) {
481-
if (!isConstant(Arg))
529+
llvm::SmallVector<SILInstruction *, 8> toDelete;
530+
for (const Operand &argOp : PAI->getArgumentOperands()) {
531+
SILInstruction *constInst = getConstant(argOp.get());
532+
if (!constInst)
482533
return false;
534+
if (auto *kp = dyn_cast<KeyPathInst>(constInst)) {
535+
auto argConv = ApplySite(PAI).getArgumentConvention(argOp).Value;
536+
// Only handle the common case of a guaranteed keypath arguments. That
537+
// refers to the callee function.
538+
if (argConv != SILArgumentConvention::Direct_Guaranteed)
539+
return false;
540+
541+
// For escaping closures:
542+
// To keep things simple, we don't do a liferange analysis to insert
543+
// compensating destroys of the keypath.
544+
// Instead we require that the PAI is the only use of the keypath (= the
545+
// common case). This allows us to just delete the now unused keypath
546+
// instruction.
547+
//
548+
// For non-escaping closures:
549+
// The keypath is not consumed by the PAI. We don't need todelete the
550+
// keypath instruction in this pass, but let dead-object-elimination clean
551+
// it up later.
552+
if (!PAI->isOnStack()) {
553+
if (getSingleNonDebugUser(kp) != PAI)
554+
return false;
555+
toDelete.push_back(kp);
556+
}
557+
}
483558
}
484559
if (!isProfitable(SubstF))
485560
return false;
@@ -491,6 +566,8 @@ bool CapturePropagation::optimizePartialApply(PartialApplyInst *PAI) {
491566
SILFunction *NewF = specializeConstClosure(PAI, SubstF);
492567
rewritePartialApply(PAI, NewF);
493568

569+
recursivelyDeleteTriviallyDeadInstructions(toDelete, /*force*/ true);
570+
494571
addFunctionToPassManagerWorklist(NewF, SubstF);
495572
return true;
496573
}

lib/SILOptimizer/Utils/SpecializationMangler.cpp

Lines changed: 48 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@
1414
#include "swift/AST/GenericEnvironment.h"
1515
#include "swift/AST/GenericSignature.h"
1616
#include "swift/AST/SubstitutionMap.h"
17+
#include "swift/Basic/MD5Stream.h"
1718
#include "swift/Demangling/ManglingMacros.h"
1819
#include "swift/SIL/SILGlobalVariable.h"
20+
#include "llvm/ADT/StringExtras.h"
1921

2022
using namespace swift;
2123
using namespace Mangle;
@@ -68,10 +70,10 @@ void FunctionSignatureSpecializationMangler::setArgumentClosureProp(
6870
}
6971

7072
void FunctionSignatureSpecializationMangler::setArgumentConstantProp(
71-
unsigned OrigArgIdx, LiteralInst *LI) {
73+
unsigned OrigArgIdx, SILInstruction *constInst) {
7274
auto &Info = OrigArgs[OrigArgIdx];
7375
Info.first = ArgumentModifierIntBase(ArgumentModifier::ConstantProp);
74-
Info.second = LI;
76+
Info.second = constInst;
7577
}
7678

7779
void FunctionSignatureSpecializationMangler::setArgumentOwnedToGuaranteed(
@@ -122,41 +124,41 @@ setReturnValueOwnedToUnowned() {
122124
}
123125

124126
void
125-
FunctionSignatureSpecializationMangler::mangleConstantProp(LiteralInst *LI) {
127+
FunctionSignatureSpecializationMangler::mangleConstantProp(SILInstruction *constInst) {
126128
// Append the prefix for constant propagation 'p'.
127129
ArgOpBuffer << 'p';
128130

129131
// Then append the unique identifier of our literal.
130-
switch (LI->getKind()) {
132+
switch (constInst->getKind()) {
131133
default:
132134
llvm_unreachable("unknown literal");
133135
case SILInstructionKind::PreviousDynamicFunctionRefInst:
134136
case SILInstructionKind::DynamicFunctionRefInst:
135137
case SILInstructionKind::FunctionRefInst: {
136138
SILFunction *F =
137-
cast<FunctionRefBaseInst>(LI)->getInitiallyReferencedFunction();
139+
cast<FunctionRefBaseInst>(constInst)->getInitiallyReferencedFunction();
138140
ArgOpBuffer << 'f';
139141
appendIdentifier(F->getName());
140142
break;
141143
}
142144
case SILInstructionKind::GlobalAddrInst: {
143-
SILGlobalVariable *G = cast<GlobalAddrInst>(LI)->getReferencedGlobal();
145+
SILGlobalVariable *G = cast<GlobalAddrInst>(constInst)->getReferencedGlobal();
144146
ArgOpBuffer << 'g';
145147
appendIdentifier(G->getName());
146148
break;
147149
}
148150
case SILInstructionKind::IntegerLiteralInst: {
149-
APInt apint = cast<IntegerLiteralInst>(LI)->getValue();
151+
APInt apint = cast<IntegerLiteralInst>(constInst)->getValue();
150152
ArgOpBuffer << 'i' << apint;
151153
break;
152154
}
153155
case SILInstructionKind::FloatLiteralInst: {
154-
APInt apint = cast<FloatLiteralInst>(LI)->getBits();
156+
APInt apint = cast<FloatLiteralInst>(constInst)->getBits();
155157
ArgOpBuffer << 'd' << apint;
156158
break;
157159
}
158160
case SILInstructionKind::StringLiteralInst: {
159-
StringLiteralInst *SLI = cast<StringLiteralInst>(LI);
161+
StringLiteralInst *SLI = cast<StringLiteralInst>(constInst);
160162
StringRef V = SLI->getValue();
161163
assert(V.size() <= 32 && "Cannot encode string of length > 32");
162164
std::string VBuffer;
@@ -175,7 +177,43 @@ FunctionSignatureSpecializationMangler::mangleConstantProp(LiteralInst *LI) {
175177
}
176178
break;
177179
}
180+
case SILInstructionKind::KeyPathInst: {
181+
// Mangle a keypath instruction by creating a MD5 hash of the printed
182+
// instruction. Everything else would be too complicated.
183+
184+
auto *kp = cast<KeyPathInst>(constInst);
185+
KeyPathPattern *pattern = kp->getPattern();
186+
187+
MD5Stream md5Stream;
188+
SILPrintContext printCtxt(md5Stream);
189+
for (auto &component : pattern->getComponents()) {
190+
component.print(printCtxt);
191+
}
192+
llvm::MD5::MD5Result md5Hash;
193+
md5Stream.final(md5Hash);
194+
SmallString<32> resultStr;
195+
llvm::MD5::stringifyResult(md5Hash, resultStr);
196+
appendStringAsIdentifier(resultStr);
197+
198+
// Also, mangle the involved types.
199+
appendType(pattern->getRootType(), nullptr);
200+
appendType(pattern->getValueType(), nullptr);
201+
202+
ArgOpBuffer << 'k';
203+
break;
204+
}
205+
}
206+
}
207+
208+
void
209+
FunctionSignatureSpecializationMangler::appendStringAsIdentifier(StringRef str) {
210+
std::string buffer;
211+
if (!str.empty() && (isDigit(str[0]) || str[0] == '_')) {
212+
buffer = "_";
213+
buffer.append(str.data(), str.size());
214+
str = buffer;
178215
}
216+
appendIdentifier(str);
179217
}
180218

181219
void
@@ -206,7 +244,7 @@ FunctionSignatureSpecializationMangler::mangleClosureProp(SILInstruction *Inst)
206244
void FunctionSignatureSpecializationMangler::mangleArgument(
207245
ArgumentModifierIntBase ArgMod, NullablePtr<SILInstruction> Inst) {
208246
if (ArgMod == ArgumentModifierIntBase(ArgumentModifier::ConstantProp)) {
209-
mangleConstantProp(cast<LiteralInst>(Inst.get()));
247+
mangleConstantProp(Inst.get());
210248
return;
211249
}
212250

0 commit comments

Comments
 (0)