Skip to content

Commit 64da348

Browse files
authored
Merge pull request #73688 from jkshtj/main
[Autodiff] Adds logic to rewrite call-sites using functions specialized by the closure-spec optimization
2 parents 3ae93a9 + 487648a commit 64da348

File tree

14 files changed

+779
-215
lines changed

14 files changed

+779
-215
lines changed

SwiftCompilerSources/Sources/Optimizer/FunctionPasses/ClosureSpecialization.swift

Lines changed: 266 additions & 53 deletions
Large diffs are not rendered by default.

SwiftCompilerSources/Sources/Optimizer/PassManager/Context.swift

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,10 @@ extension Context {
6060
_bridged.lookupFunction($0).function
6161
}
6262
}
63+
64+
func notifyNewFunction(function: Function, derivedFrom: Function) {
65+
_bridged.addFunctionToPassManagerWorklist(function.bridged, derivedFrom.bridged)
66+
}
6367
}
6468

6569
/// A context which allows mutation of a function's SIL.
@@ -357,7 +361,7 @@ struct FunctionPassContext : MutatingContext {
357361
return String(taking: _bridged.mangleOutlinedVariable(function.bridged))
358362
}
359363

360-
func mangle(withClosureArgs closureArgs: [Value], closureArgIndices: [Int], from applySiteCallee: Function) -> String {
364+
func mangle(withClosureArguments closureArgs: [Value], closureArgIndices: [Int], from applySiteCallee: Function) -> String {
361365
closureArgs.withBridgedValues { bridgedClosureArgsRef in
362366
closureArgIndices.withBridgedArrayRef{bridgedClosureArgIndicesRef in
363367
String(taking: _bridged.mangleWithClosureArgs(
@@ -392,13 +396,13 @@ struct FunctionPassContext : MutatingContext {
392396
}
393397
}
394398

395-
func buildSpecializedFunction(specializedFunction: Function, buildFn: (Function, FunctionPassContext) -> ()) {
399+
func buildSpecializedFunction<T>(specializedFunction: Function, buildFn: (Function, FunctionPassContext) -> T) -> T {
396400
let nestedFunctionPassContext =
397401
FunctionPassContext(_bridged: _bridged.initializeNestedPassContext(specializedFunction.bridged))
398402

399403
defer { _bridged.deinitializedNestedPassContext() }
400404

401-
buildFn(specializedFunction, nestedFunctionPassContext)
405+
return buildFn(specializedFunction, nestedFunctionPassContext)
402406
}
403407
}
404408

SwiftCompilerSources/Sources/Optimizer/Utilities/Test.swift

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,8 @@ public func registerOptimizerTests() {
167167
parseTestSpecificationTest,
168168
variableIntroducerTest,
169169
gatherCallSitesTest,
170-
specializedFunctionSignatureAndBodyTest
170+
specializedFunctionSignatureAndBodyTest,
171+
rewrittenCallerBodyTest
171172
)
172173

173174
// Finally register the thunk they all call through.

SwiftCompilerSources/Sources/SIL/Builder.swift

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,18 @@ public struct Builder {
151151
return notifyNew(endInit.getAs(EndInitLetRefInst.self))
152152
}
153153

154+
@discardableResult
155+
public func createRetainValue(operand: Value) -> RetainValueInst {
156+
let retain = bridged.createRetainValue(operand.bridged)
157+
return notifyNew(retain.getAs(RetainValueInst.self))
158+
}
159+
160+
@discardableResult
161+
public func createReleaseValue(operand: Value) -> ReleaseValueInst {
162+
let release = bridged.createReleaseValue(operand.bridged)
163+
return notifyNew(release.getAs(ReleaseValueInst.self))
164+
}
165+
154166
@discardableResult
155167
public func createStrongRetain(operand: Value) -> StrongRetainInst {
156168
let retain = bridged.createStrongRetain(operand.bridged)

SwiftCompilerSources/Sources/SIL/Operand.swift

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,10 @@ public struct OperandArray : RandomAccessCollection, CustomReflectable {
6262
self.count = count
6363
}
6464

65+
static public var empty: OperandArray {
66+
OperandArray(base: OptionalBridgedOperand(bridged: nil), count: 0)
67+
}
68+
6569
public var startIndex: Int { return 0 }
6670
public var endIndex: Int { return count }
6771

include/swift/SIL/SILBridging.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1210,6 +1210,10 @@ struct BridgedBuilder{
12101210
SWIFT_IMPORT_UNSAFE BRIDGED_INLINE BridgedInstruction createBeginDeallocRef(BridgedValue reference,
12111211
BridgedValue allocation) const;
12121212
SWIFT_IMPORT_UNSAFE BRIDGED_INLINE BridgedInstruction createEndInitLetRef(BridgedValue op) const;
1213+
SWIFT_IMPORT_UNSAFE BRIDGED_INLINE BridgedInstruction
1214+
createRetainValue(BridgedValue op) const;
1215+
SWIFT_IMPORT_UNSAFE BRIDGED_INLINE BridgedInstruction
1216+
createReleaseValue(BridgedValue op) const;
12131217
SWIFT_IMPORT_UNSAFE BRIDGED_INLINE BridgedInstruction createStrongRetain(BridgedValue op) const;
12141218
SWIFT_IMPORT_UNSAFE BRIDGED_INLINE BridgedInstruction createStrongRelease(BridgedValue op) const;
12151219
SWIFT_IMPORT_UNSAFE BRIDGED_INLINE BridgedInstruction createUnownedRetain(BridgedValue op) const;

include/swift/SIL/SILBridgingImpl.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#ifndef SWIFT_SIL_SILBRIDGING_IMPL_H
2020
#define SWIFT_SIL_SILBRIDGING_IMPL_H
2121

22+
#include "SILBridging.h"
2223
#include "swift/AST/Builtins.h"
2324
#include "swift/AST/Decl.h"
2425
#include "swift/AST/SubstitutionMap.h"
@@ -1589,6 +1590,18 @@ BridgedInstruction BridgedBuilder::createEndInitLetRef(BridgedValue op) const {
15891590
return {unbridged().createEndInitLetRef(regularLoc(), op.getSILValue())};
15901591
}
15911592

1593+
BridgedInstruction BridgedBuilder::createRetainValue(BridgedValue op) const {
1594+
auto b = unbridged();
1595+
return {b.createRetainValue(regularLoc(), op.getSILValue(),
1596+
b.getDefaultAtomicity())};
1597+
}
1598+
1599+
BridgedInstruction BridgedBuilder::createReleaseValue(BridgedValue op) const {
1600+
auto b = unbridged();
1601+
return {b.createReleaseValue(regularLoc(), op.getSILValue(),
1602+
b.getDefaultAtomicity())};
1603+
}
1604+
15921605
BridgedInstruction BridgedBuilder::createStrongRetain(BridgedValue op) const {
15931606
auto b = unbridged();
15941607
return {b.createStrongRetain(regularLoc(), op.getSILValue(), b.getDefaultAtomicity())};

include/swift/SILOptimizer/OptimizerBridging.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,9 @@ struct BridgedPassContext {
337337
BRIDGED_INLINE bool continueWithNextSubpassRun(OptionalBridgedInstruction inst) const;
338338
SWIFT_IMPORT_UNSAFE BRIDGED_INLINE BridgedPassContext initializeNestedPassContext(BridgedFunction newFunction) const;
339339
BRIDGED_INLINE void deinitializedNestedPassContext() const;
340+
BRIDGED_INLINE void
341+
addFunctionToPassManagerWorklist(BridgedFunction newFunction,
342+
BridgedFunction oldFunction) const;
340343

341344
// SSAUpdater
342345

include/swift/SILOptimizer/OptimizerBridgingImpl.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,13 @@ void BridgedPassContext::SSAUpdater_initialize(
454454
BridgedValue::castToOwnership(ownership));
455455
}
456456

457+
void BridgedPassContext::addFunctionToPassManagerWorklist(
458+
BridgedFunction newFunction, BridgedFunction oldFunction) const {
459+
swift::SILPassManager *pm = invocation->getPassManager();
460+
pm->addFunctionToWorklist(newFunction.getFunction(),
461+
oldFunction.getFunction());
462+
}
463+
457464
void BridgedPassContext::SSAUpdater_addAvailableValue(BridgedBasicBlock block, BridgedValue value) const {
458465
invocation->getSSAUpdater()->addAvailableValue(block.unbridged(),
459466
value.getSILValue());

lib/SILOptimizer/PassManager/PassPipeline.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1008,10 +1008,14 @@ SILPassPipelinePlan::getPerformancePassPipeline(const SILOptions &Options) {
10081008
if (Options.StopOptimizationAfterSerialization)
10091009
return P;
10101010

1011+
P.addAutodiffClosureSpecialization();
1012+
10111013
// After serialization run the function pass pipeline to iteratively lower
10121014
// high-level constructs like @_semantics calls.
10131015
addMidLevelFunctionPipeline(P);
10141016

1017+
P.addAutodiffClosureSpecialization();
1018+
10151019
// Perform optimizations that specialize.
10161020
addClosureSpecializePassPipeline(P);
10171021

0 commit comments

Comments
 (0)