Skip to content

Commit 00b795f

Browse files
authored
Merge pull request #41120 from eeckstein/capture-propagate-keypaths
CapturePropagation: specialize closures which capture a constant keypath
2 parents f6f3869 + 34d79e3 commit 00b795f

File tree

15 files changed

+441
-55
lines changed

15 files changed

+441
-55
lines changed

docs/ABI/Mangling.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1139,6 +1139,7 @@ Some kinds need arguments, which precede ``Tf``.
11391139
CONST-PROP ::= 'i' NATURAL_ZERO // 64-bit-integer
11401140
CONST-PROP ::= 'd' NATURAL_ZERO // float-as-64-bit-integer
11411141
CONST-PROP ::= 's' ENCODING // string literal. Consumes one identifier argument.
1142+
CONST-PROP ::= 'k' // keypath. Consumes one identifier - the SHA1 of the keypath and two types (root and value).
11421143

11431144
ENCODING ::= 'b' // utf8
11441145
ENCODING ::= 'w' // utf16

include/swift/Basic/MD5Stream.h

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
//===--- MD5Stream.h - raw_ostream that compute MD5 ------------*- C++ -*-===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2014 - 2022 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef SWIFT_MD5STREAM_H
14+
#define SWIFT_MD5STREAM_H
15+
16+
#include "llvm/Support/MD5.h"
17+
#include "llvm/Support/raw_ostream.h"
18+
19+
namespace swift {
20+
21+
/// An output stream which calculates the MD5 hash of the streamed data.
22+
class MD5Stream : public llvm::raw_ostream {
23+
private:
24+
25+
uint64_t Pos = 0;
26+
llvm::MD5 Hash;
27+
28+
void write_impl(const char *Ptr, size_t Size) override {
29+
Hash.update(ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(Ptr), Size));
30+
Pos += Size;
31+
}
32+
33+
uint64_t current_pos() const override { return Pos; }
34+
35+
public:
36+
37+
void final(llvm::MD5::MD5Result &Result) {
38+
flush();
39+
Hash.final(Result);
40+
}
41+
};
42+
43+
} // namespace swift
44+
45+
#endif

include/swift/Demangling/Demangle.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ enum class FunctionSigSpecializationParamKind : unsigned {
109109
BoxToValue = 6,
110110
BoxToStack = 7,
111111
InOutToOut = 8,
112+
ConstantPropKeyPath = 9,
112113

113114
// Option Set Flags use bits 6-31. This gives us 26 bits to use for option
114115
// flags.

include/swift/SIL/SILInstruction.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ class StringLiteralExpr;
7878
class ValueDecl;
7979
class VarDecl;
8080
class FunctionRefBaseInst;
81+
class SILPrintContext;
8182

8283
template <typename ImplClass> class SILClonerWithScopes;
8384

@@ -3489,7 +3490,9 @@ class KeyPathPatternComponent {
34893490

34903491
void incrementRefCounts() const;
34913492
void decrementRefCounts() const;
3492-
3493+
3494+
void print(SILPrintContext &ctxt) const;
3495+
34933496
void Profile(llvm::FoldingSetNodeID &ID);
34943497
};
34953498

include/swift/SILOptimizer/Utils/SpecializationMangler.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,9 @@ class FunctionSignatureSpecializationMangler : public SpecializationMangler {
9595
FunctionSignatureSpecializationMangler(SpecializationPass Pass,
9696
IsSerialized_t Serialized,
9797
SILFunction *F);
98-
void setArgumentConstantProp(unsigned OrigArgIdx, LiteralInst *LI);
98+
void setArgumentConstantProp(unsigned OrigArgIdx, SILInstruction *constInst);
99+
void appendStringAsIdentifier(StringRef str);
100+
99101
void setArgumentClosureProp(unsigned OrigArgIdx, PartialApplyInst *PAI);
100102
void setArgumentClosureProp(unsigned OrigArgIdx,
101103
ThinToThickFunctionInst *TTTFI);
@@ -112,7 +114,7 @@ class FunctionSignatureSpecializationMangler : public SpecializationMangler {
112114
std::string mangle();
113115

114116
private:
115-
void mangleConstantProp(LiteralInst *LI);
117+
void mangleConstantProp(SILInstruction *constInst);
116118
void mangleClosureProp(SILInstruction *Inst);
117119
void mangleArgument(ArgumentModifierIntBase ArgMod,
118120
NullablePtr<SILInstruction> Inst);

lib/Demangling/Demangler.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2816,10 +2816,12 @@ NodePointer Demangler::demangleFunctionSpecialization() {
28162816
case FunctionSigSpecializationParamKind::ConstantPropFunction:
28172817
case FunctionSigSpecializationParamKind::ConstantPropGlobal:
28182818
case FunctionSigSpecializationParamKind::ConstantPropString:
2819+
case FunctionSigSpecializationParamKind::ConstantPropKeyPath:
28192820
case FunctionSigSpecializationParamKind::ClosureProp: {
28202821
size_t FixedChildren = Param->getNumChildren();
28212822
while (NodePointer Ty = popNode(Node::Kind::Type)) {
2822-
if (ParamKind != FunctionSigSpecializationParamKind::ClosureProp)
2823+
if (ParamKind != FunctionSigSpecializationParamKind::ClosureProp &&
2824+
ParamKind != FunctionSigSpecializationParamKind::ConstantPropKeyPath)
28232825
return nullptr;
28242826
Param = addChild(Param, Ty);
28252827
}
@@ -2901,6 +2903,14 @@ NodePointer Demangler::demangleFuncSpecParam(Node::Kind Kind) {
29012903
Node::Kind::FunctionSignatureSpecializationParamPayload,
29022904
Encoding));
29032905
}
2906+
case 'k': {
2907+
// Consumes two types and a SHA1 identifier.
2908+
return addChild(
2909+
Param,
2910+
createNode(Node::Kind::FunctionSignatureSpecializationParamKind,
2911+
Node::IndexType(FunctionSigSpecializationParamKind::
2912+
ConstantPropKeyPath)));
2913+
}
29042914
default:
29052915
return nullptr;
29062916
}

lib/Demangling/NodePrinter.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1054,6 +1054,17 @@ void NodePrinter::printFunctionSigSpecializationParams(NodePointer Node,
10541054
Printer << "'";
10551055
Printer << "]";
10561056
break;
1057+
case FunctionSigSpecializationParamKind::ConstantPropKeyPath:
1058+
Printer << "[";
1059+
print(Node->getChild(Idx++), depth + 1);
1060+
Printer << " : ";
1061+
print(Node->getChild(Idx++), depth + 1);
1062+
Printer << "<";
1063+
print(Node->getChild(Idx++), depth + 1);
1064+
Printer << ",";
1065+
print(Node->getChild(Idx++), depth + 1);
1066+
Printer << ">]";
1067+
break;
10571068
case FunctionSigSpecializationParamKind::ClosureProp:
10581069
Printer << "[";
10591070
print(Node->getChild(Idx++), depth + 1);
@@ -1611,6 +1622,9 @@ NodePointer NodePrinter::print(NodePointer Node, unsigned depth,
16111622
case FunctionSigSpecializationParamKind::ConstantPropString:
16121623
Printer << "Constant Propagated String";
16131624
return nullptr;
1625+
case FunctionSigSpecializationParamKind::ConstantPropKeyPath:
1626+
Printer << "Constant Propagated KeyPath";
1627+
return nullptr;
16141628
case FunctionSigSpecializationParamKind::ClosureProp:
16151629
Printer << "Closure Propagated";
16161630
return nullptr;

lib/Demangling/Remangler.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1322,6 +1322,7 @@ ManglingError Remangler::mangleFunctionSignatureSpecialization(Node *node,
13221322
break;
13231323
}
13241324
case FunctionSigSpecializationParamKind::ClosureProp:
1325+
case FunctionSigSpecializationParamKind::ConstantPropKeyPath:
13251326
RETURN_IF_ERROR(mangleIdentifier(Param->getChild(1), depth + 1));
13261327
for (unsigned i = 2, e = Param->getNumChildren(); i != e; ++i) {
13271328
RETURN_IF_ERROR(mangleType(Param->getChild(i), depth + 1));
@@ -1399,6 +1400,9 @@ Remangler::mangleFunctionSignatureSpecializationParam(Node *node,
13991400
}
14001401
break;
14011402
}
1403+
case FunctionSigSpecializationParamKind::ConstantPropKeyPath:
1404+
Buffer << "pk";
1405+
break;
14021406
case FunctionSigSpecializationParamKind::ClosureProp:
14031407
Buffer << 'c';
14041408
break;

lib/IRGen/IRGen.cpp

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "swift/AST/TBDGenRequests.h"
2828
#include "swift/Basic/Defer.h"
2929
#include "swift/Basic/Dwarf.h"
30+
#include "swift/Basic/MD5Stream.h"
3031
#include "swift/Basic/Platform.h"
3132
#include "swift/Basic/Statistic.h"
3233
#include "swift/Basic/Version.h"
@@ -69,7 +70,6 @@
6970
#include "llvm/Support/ErrorHandling.h"
7071
#include "llvm/Support/FileSystem.h"
7172
#include "llvm/Support/FormattedStream.h"
72-
#include "llvm/Support/MD5.h"
7373
#include "llvm/Support/Mutex.h"
7474
#include "llvm/Support/Path.h"
7575
#include "llvm/Target/TargetMachine.h"
@@ -376,30 +376,6 @@ void swift::performLLVMOptimizations(const IRGenOptions &Opts,
376376
}
377377
}
378378

379-
namespace {
380-
/// An output stream which calculates the MD5 hash of the streamed data.
381-
class MD5Stream : public llvm::raw_ostream {
382-
private:
383-
384-
uint64_t Pos = 0;
385-
llvm::MD5 Hash;
386-
387-
void write_impl(const char *Ptr, size_t Size) override {
388-
Hash.update(ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(Ptr), Size));
389-
Pos += Size;
390-
}
391-
392-
uint64_t current_pos() const override { return Pos; }
393-
394-
public:
395-
396-
void final(MD5::MD5Result &Result) {
397-
flush();
398-
Hash.final(Result);
399-
}
400-
};
401-
} // end anonymous namespace
402-
403379
/// Computes the MD5 hash of the llvm \p Module including the compiler version
404380
/// and options which influence the compilation.
405381
static MD5::MD5Result getHashOfModule(const IRGenOptions &Opts,

lib/SIL/IR/SILPrinter.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3861,6 +3861,11 @@ void SILSpecializeAttr::print(llvm::raw_ostream &OS) const {
38613861
}
38623862
}
38633863

3864+
void KeyPathPatternComponent::print(SILPrintContext &ctxt) const {
3865+
SILPrinter printer(ctxt);
3866+
printer.printKeyPathPatternComponent(*this);
3867+
}
3868+
38643869
//===----------------------------------------------------------------------===//
38653870
// SILPrintContext members
38663871
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)