Skip to content

Commit a7f8d6c

Browse files
committed
[Autodiff] Adds bridging code in preparation for the Swift based Autodiff closure-spec pass
1 parent 73ed03c commit a7f8d6c

File tree

10 files changed

+239
-87
lines changed

10 files changed

+239
-87
lines changed

include/swift/SIL/SILBridging.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,7 @@ struct BridgedType {
347347
BRIDGED_INLINE bool isMetatype() const;
348348
BRIDGED_INLINE bool isNoEscapeFunction() const;
349349
BRIDGED_INLINE bool containsNoEscapeFunction() const;
350+
BRIDGED_INLINE bool isThickFunction() const;
350351
BRIDGED_INLINE bool isAsyncFunction() const;
351352
BRIDGED_INLINE bool isEmpty(BridgedFunction f) const;
352353
BRIDGED_INLINE TraitResult canBeClass() const;
@@ -373,7 +374,8 @@ struct BridgedType {
373374
BRIDGED_INLINE bool isEndCaseIterator(EnumElementIterator i) const;
374375
SWIFT_IMPORT_UNSAFE BRIDGED_INLINE BridgedType getEnumCasePayload(EnumElementIterator i, BridgedFunction f) const;
375376
BRIDGED_INLINE SwiftInt getNumTupleElements() const;
376-
SWIFT_IMPORT_UNSAFE BRIDGED_INLINE BridgedType getTupleElementType(SwiftInt idx) const;
377+
SWIFT_IMPORT_UNSAFE BRIDGED_INLINE BridgedType
378+
getTupleElementType(SwiftInt idx) const;
377379
SWIFT_IMPORT_UNSAFE BRIDGED_INLINE BridgedType getFunctionTypeWithNoEscape(bool withNoEscape) const;
378380
};
379381

@@ -547,13 +549,15 @@ struct BridgedFunction {
547549
BRIDGED_INLINE bool isAvailableExternally() const;
548550
BRIDGED_INLINE bool isTransparent() const;
549551
BRIDGED_INLINE bool isAsync() const;
552+
BRIDGED_INLINE bool isReabstractionThunk() const;
550553
BRIDGED_INLINE bool isGlobalInitFunction() const;
551554
BRIDGED_INLINE bool isGlobalInitOnceFunction() const;
552555
BRIDGED_INLINE bool isDestructor() const;
553556
BRIDGED_INLINE bool isGeneric() const;
554557
BRIDGED_INLINE bool hasSemanticsAttr(BridgedStringRef attrName) const;
555558
BRIDGED_INLINE bool hasUnsafeNonEscapableResult() const;
556559
BRIDGED_INLINE bool hasResultDependsOnSelf() const;
560+
bool mayBindDynamicSelf() const;
557561
BRIDGED_INLINE EffectsKind getEffectAttribute() const;
558562
BRIDGED_INLINE PerformanceConstraints getPerformanceConstraints() const;
559563
BRIDGED_INLINE InlineStrategy getInlineStrategy() const;
@@ -566,6 +570,9 @@ struct BridgedFunction {
566570
BRIDGED_INLINE void setIsPerformanceConstraint(bool isPerfConstraint) const;
567571
BRIDGED_INLINE bool isResilientNominalDecl(BridgedNominalTypeDecl decl) const;
568572
BRIDGED_INLINE BridgedType getLoweredType(BridgedASTType type) const;
573+
bool isTrapNoReturn() const;
574+
bool isAutodiffVJP() const;
575+
SwiftInt specializationLevel() const;
569576

570577
enum class ParseEffectsMode {
571578
argumentEffectsFromSource,
@@ -658,6 +665,7 @@ struct BridgedSubstitutionMap {
658665

659666
BRIDGED_INLINE BridgedSubstitutionMap();
660667
BRIDGED_INLINE bool isEmpty() const;
668+
BRIDGED_INLINE bool hasAnySubstitutableParams() const;
661669
};
662670

663671
struct BridgedTypeArray {
@@ -917,6 +925,7 @@ struct BridgedInstruction {
917925
SWIFT_IMPORT_UNSAFE BRIDGED_INLINE BridgedSubstitutionMap ApplySite_getSubstitutionMap() const;
918926
SWIFT_IMPORT_UNSAFE BRIDGED_INLINE BridgedASTType ApplySite_getSubstitutedCalleeType() const;
919927
BRIDGED_INLINE SwiftInt ApplySite_getNumArguments() const;
928+
BRIDGED_INLINE bool ApplySite_isCalleeNoReturn() const;
920929
BRIDGED_INLINE SwiftInt FullApplySite_numIndirectResultArguments() const;
921930

922931
// =========================================================================//

include/swift/SIL/SILBridgingImpl.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,10 @@ bool BridgedType::containsNoEscapeFunction() const {
262262
return unbridged().containsNoEscapeFunction();
263263
}
264264

265+
bool BridgedType::isThickFunction() const {
266+
return unbridged().isThickFunction();
267+
}
268+
265269
bool BridgedType::isAsyncFunction() const {
266270
return unbridged().isAsyncFunction();
267271
}
@@ -530,6 +534,10 @@ bool BridgedSubstitutionMap::isEmpty() const {
530534
return unbridged().empty();
531535
}
532536

537+
bool BridgedSubstitutionMap::hasAnySubstitutableParams() const {
538+
return unbridged().hasAnySubstitutableParams();
539+
}
540+
533541
//===----------------------------------------------------------------------===//
534542
// BridgedLocation
535543
//===----------------------------------------------------------------------===//
@@ -635,6 +643,10 @@ bool BridgedFunction::isAsync() const {
635643
return getFunction()->isAsync();
636644
}
637645

646+
bool BridgedFunction::isReabstractionThunk() const {
647+
return getFunction()->isThunk() == swift::IsReabstractionThunk;
648+
}
649+
638650
bool BridgedFunction::isGlobalInitFunction() const {
639651
return getFunction()->isGlobalInit();
640652
}
@@ -1283,6 +1295,10 @@ SwiftInt BridgedInstruction::ApplySite_getNumArguments() const {
12831295
return swift::ApplySite(unbridged()).getNumArguments();
12841296
}
12851297

1298+
bool BridgedInstruction::ApplySite_isCalleeNoReturn() const {
1299+
return swift::ApplySite(unbridged()).isCalleeNoReturn();
1300+
}
1301+
12861302
SwiftInt BridgedInstruction::FullApplySite_numIndirectResultArguments() const {
12871303
auto fas = swift::FullApplySite(unbridged());
12881304
return fas.getNumIndirectSILResults();

include/swift/SIL/SILType.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -555,6 +555,13 @@ class SILType {
555555
// Handle whatever AST types are known to hold functions. Namely tuples.
556556
return ty->isNoEscape();
557557
}
558+
559+
bool isThickFunction() const {
560+
if (auto *fTy = getASTType()->getAs<SILFunctionType>()) {
561+
return fTy->getRepresentation() == SILFunctionType::Representation::Thick;
562+
}
563+
return false;
564+
}
558565

559566
bool isAsyncFunction() const {
560567
if (auto *fTy = getASTType()->getAs<SILFunctionType>()) {
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
//===-------------------------- ClosureSpecializer.h ------------------------------===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2014 - 2019 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===-----------------------------------------------------------------------------===//
12+
#ifndef SWIFT_SILOPTIMIZER_CLOSURESPECIALIZER_H
13+
#define SWIFT_SILOPTIMIZER_CLOSURESPECIALIZER_H
14+
15+
#include "swift/SIL/SILFunction.h"
16+
17+
namespace swift {
18+
19+
/// If \p function is a function-signature specialization for a constant-
20+
/// propagated function argument, returns 1.
21+
/// If \p function is a specialization of such a specialization, returns 2.
22+
/// And so on.
23+
int getSpecializationLevel(SILFunction *f);
24+
25+
enum class AutoDiffFunctionComponent : char { JVP = 'f', VJP = 'r' };
26+
27+
/// Returns true if the function is the JVP or the VJP corresponding to
28+
/// a differentiable function.
29+
bool isDifferentiableFuncComponent(
30+
SILFunction *f,
31+
AutoDiffFunctionComponent component = AutoDiffFunctionComponent::VJP);
32+
33+
} // namespace swift
34+
#endif

include/swift/SILOptimizer/OptimizerBridging.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,9 @@ struct BridgedPassContext {
239239
SWIFT_IMPORT_UNSAFE BridgedOwnedString mangleWithDeadArgs(const SwiftInt * _Nullable deadArgs,
240240
SwiftInt numDeadArgs,
241241
BridgedFunction function) const;
242+
SWIFT_IMPORT_UNSAFE BridgedOwnedString mangleWithClosureArgs(BridgedValueArray closureArgs,
243+
BridgedArrayRef closureArgIndices,
244+
BridgedFunction applySiteCallee) const;
242245

243246
SWIFT_IMPORT_UNSAFE BridgedGlobalVar createGlobalVariable(BridgedStringRef name, BridgedType type,
244247
bool isPrivate) const;

include/swift/SILOptimizer/OptimizerBridgingImpl.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,12 @@
1919
#ifndef SWIFT_SILOPTIMIZER_OPTIMIZERBRIDGING_IMPL_H
2020
#define SWIFT_SILOPTIMIZER_OPTIMIZERBRIDGING_IMPL_H
2121

22-
#include "swift/SILOptimizer/OptimizerBridging.h"
22+
#include "swift/Demangling/Demangle.h"
2323
#include "swift/SILOptimizer/Analysis/AliasAnalysis.h"
2424
#include "swift/SILOptimizer/Analysis/BasicCalleeAnalysis.h"
2525
#include "swift/SILOptimizer/Analysis/DeadEndBlocksAnalysis.h"
2626
#include "swift/SILOptimizer/Analysis/DominanceAnalysis.h"
27+
#include "swift/SILOptimizer/OptimizerBridging.h"
2728
#include "swift/SILOptimizer/PassManager/PassManager.h"
2829
#include "swift/SILOptimizer/Utils/InstOptUtils.h"
2930

include/swift/SILOptimizer/Utils/CFGOptUtils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,8 @@ bool mergeBasicBlockWithSuccessor(SILBasicBlock *bb, DominanceInfo *domInfo,
190190
/// quadratic.
191191
bool mergeBasicBlocks(SILFunction *f);
192192

193+
bool isTrapNoReturnFunction(SILFunction *f);
194+
193195
/// Return true if we conservatively find all bb's that are non-failure exit
194196
/// basic blocks and place them in \p bbs. If we find something we don't
195197
/// understand, bail.

lib/SILOptimizer/IPO/ClosureSpecializer.cpp

Lines changed: 97 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,9 @@
5656
//===----------------------------------------------------------------------===//
5757

5858
#define DEBUG_TYPE "closure-specialization"
59+
#include "swift/SILOptimizer/IPO/ClosureSpecializer.h"
5960
#include "swift/Basic/Range.h"
61+
#include "swift/Demangling/Demangle.h"
6062
#include "swift/Demangling/Demangler.h"
6163
#include "swift/SIL/InstructionUtils.h"
6264
#include "swift/SIL/SILCloner.h"
@@ -103,6 +105,101 @@ static bool isSupportedClosureKind(const SILInstruction *I) {
103105
return isa<ThinToThickFunctionInst>(I) || isa<PartialApplyInst>(I);
104106
}
105107

108+
static const int SpecializationLevelLimit = 2;
109+
110+
static int getSpecializationLevelRecursive(StringRef funcName,
111+
Demangler &parent) {
112+
using namespace Demangle;
113+
114+
Demangler demangler;
115+
demangler.providePreallocatedMemory(parent);
116+
117+
// Check for this kind of node tree:
118+
//
119+
// kind=Global
120+
// kind=FunctionSignatureSpecialization
121+
// kind=SpecializationPassID, index=1
122+
// kind=FunctionSignatureSpecializationParam
123+
// kind=FunctionSignatureSpecializationParamKind, index=5
124+
// kind=FunctionSignatureSpecializationParamPayload, text="..."
125+
//
126+
Node *root = demangler.demangleSymbol(funcName);
127+
if (!root)
128+
return 0;
129+
if (root->getKind() != Node::Kind::Global)
130+
return 0;
131+
Node *funcSpec = root->getFirstChild();
132+
if (!funcSpec || funcSpec->getNumChildren() < 2)
133+
return 0;
134+
if (funcSpec->getKind() != Node::Kind::FunctionSignatureSpecialization)
135+
return 0;
136+
137+
// Match any function specialization. We check for constant propagation at the
138+
// parameter level.
139+
Node *param = funcSpec->getChild(0);
140+
if (param->getKind() != Node::Kind::SpecializationPassID)
141+
return SpecializationLevelLimit + 1; // unrecognized format
142+
143+
unsigned maxParamLevel = 0;
144+
for (unsigned paramIdx = 1; paramIdx < funcSpec->getNumChildren();
145+
++paramIdx) {
146+
Node *param = funcSpec->getChild(paramIdx);
147+
if (param->getKind() != Node::Kind::FunctionSignatureSpecializationParam)
148+
return SpecializationLevelLimit + 1; // unrecognized format
149+
150+
// A parameter is recursive if it has a kind with index and type payload
151+
if (param->getNumChildren() < 2)
152+
continue;
153+
154+
Node *kindNd = param->getChild(0);
155+
if (kindNd->getKind() !=
156+
Node::Kind::FunctionSignatureSpecializationParamKind) {
157+
return SpecializationLevelLimit + 1; // unrecognized format
158+
}
159+
auto kind = FunctionSigSpecializationParamKind(kindNd->getIndex());
160+
if (kind != FunctionSigSpecializationParamKind::ConstantPropFunction)
161+
continue;
162+
Node *payload = param->getChild(1);
163+
if (payload->getKind() !=
164+
Node::Kind::FunctionSignatureSpecializationParamPayload) {
165+
return SpecializationLevelLimit + 1; // unrecognized format
166+
}
167+
// Check if the specialized function is a specialization itself.
168+
unsigned paramLevel =
169+
1 + getSpecializationLevelRecursive(payload->getText(), demangler);
170+
if (paramLevel > maxParamLevel)
171+
maxParamLevel = paramLevel;
172+
}
173+
return maxParamLevel;
174+
}
175+
176+
//===----------------------------------------------------------------------===//
177+
// Publicly visible for bridging
178+
//===----------------------------------------------------------------------===//
179+
180+
int swift::getSpecializationLevel(SILFunction *f) {
181+
Demangle::StackAllocatedDemangler<1024> demangler;
182+
return getSpecializationLevelRecursive(f->getName(), demangler);
183+
}
184+
185+
bool swift::isDifferentiableFuncComponent(
186+
SILFunction *f, AutoDiffFunctionComponent expectedComponent) {
187+
Demangle::Context Ctx;
188+
if (auto *root = Ctx.demangleSymbolAsNode(f->getName())) {
189+
if (auto *node =
190+
root->findByKind(Demangle::Node::Kind::AutoDiffFunctionKind, 3)) {
191+
if (node->hasIndex()) {
192+
auto component = (char)node->getIndex();
193+
if (component == (char)expectedComponent) {
194+
return true;
195+
}
196+
}
197+
}
198+
}
199+
200+
return false;
201+
}
202+
106203
//===----------------------------------------------------------------------===//
107204
// Closure Spec Cloner Interface
108205
//===----------------------------------------------------------------------===//
@@ -1084,82 +1181,6 @@ static bool canSpecializeFullApplySite(FullApplySiteKind kind) {
10841181
llvm_unreachable("covered switch");
10851182
}
10861183

1087-
const int SpecializationLevelLimit = 2;
1088-
1089-
static int getSpecializationLevelRecursive(StringRef funcName, Demangler &parent) {
1090-
using namespace Demangle;
1091-
1092-
Demangler demangler;
1093-
demangler.providePreallocatedMemory(parent);
1094-
1095-
// Check for this kind of node tree:
1096-
//
1097-
// kind=Global
1098-
// kind=FunctionSignatureSpecialization
1099-
// kind=SpecializationPassID, index=1
1100-
// kind=FunctionSignatureSpecializationParam
1101-
// kind=FunctionSignatureSpecializationParamKind, index=5
1102-
// kind=FunctionSignatureSpecializationParamPayload, text="..."
1103-
//
1104-
Node *root = demangler.demangleSymbol(funcName);
1105-
if (!root)
1106-
return 0;
1107-
if (root->getKind() != Node::Kind::Global)
1108-
return 0;
1109-
Node *funcSpec = root->getFirstChild();
1110-
if (!funcSpec || funcSpec->getNumChildren() < 2)
1111-
return 0;
1112-
if (funcSpec->getKind() != Node::Kind::FunctionSignatureSpecialization)
1113-
return 0;
1114-
1115-
// Match any function specialization. We check for constant propagation at the
1116-
// parameter level.
1117-
Node *param = funcSpec->getChild(0);
1118-
if (param->getKind() != Node::Kind::SpecializationPassID)
1119-
return SpecializationLevelLimit + 1; // unrecognized format
1120-
1121-
unsigned maxParamLevel = 0;
1122-
for (unsigned paramIdx = 1; paramIdx < funcSpec->getNumChildren();
1123-
++paramIdx) {
1124-
Node *param = funcSpec->getChild(paramIdx);
1125-
if (param->getKind() != Node::Kind::FunctionSignatureSpecializationParam)
1126-
return SpecializationLevelLimit + 1; // unrecognized format
1127-
1128-
// A parameter is recursive if it has a kind with index and type payload
1129-
if (param->getNumChildren() < 2)
1130-
continue;
1131-
1132-
Node *kindNd = param->getChild(0);
1133-
if (kindNd->getKind()
1134-
!= Node::Kind::FunctionSignatureSpecializationParamKind) {
1135-
return SpecializationLevelLimit + 1; // unrecognized format
1136-
}
1137-
auto kind = FunctionSigSpecializationParamKind(kindNd->getIndex());
1138-
if (kind != FunctionSigSpecializationParamKind::ConstantPropFunction)
1139-
continue;
1140-
Node *payload = param->getChild(1);
1141-
if (payload->getKind()
1142-
!= Node::Kind::FunctionSignatureSpecializationParamPayload) {
1143-
return SpecializationLevelLimit + 1; // unrecognized format
1144-
}
1145-
// Check if the specialized function is a specialization itself.
1146-
unsigned paramLevel =
1147-
1 + getSpecializationLevelRecursive(payload->getText(), demangler);
1148-
if (paramLevel > maxParamLevel)
1149-
maxParamLevel = paramLevel;
1150-
}
1151-
return maxParamLevel;
1152-
}
1153-
1154-
/// If \p function is a function-signature specialization for a constant-
1155-
/// propagated function argument, returns 1.
1156-
/// If \p function is a specialization of such a specialization, returns 2.
1157-
/// And so on.
1158-
static int getSpecializationLevel(SILFunction *f) {
1159-
Demangle::StackAllocatedDemangler<1024> demangler;
1160-
return getSpecializationLevelRecursive(f->getName(), demangler);
1161-
}
1162-
11631184
bool SILClosureSpecializerTransform::gatherCallSites(
11641185
SILFunction *Caller,
11651186
llvm::SmallVectorImpl<std::unique_ptr<ClosureInfo>> &ClosureCandidates,

0 commit comments

Comments
 (0)