Skip to content

Commit c8375c0

Browse files
committed
[Autodiff] Adds part of the closure-specialization optimization pass
Changes in this CR add part of the, Swift based, Autodiff specific closure specialization optimization pass. The pass does not modify any code nor does it even exist in any of the optimization pipelines. The rationale for pushing this partially complete optimization pass upstream is to keep up with the breaking changes in the underlying Swift based compiler infrastructure.
1 parent 003c908 commit c8375c0

File tree

7 files changed

+985
-1
lines changed

7 files changed

+985
-1
lines changed

SwiftCompilerSources/Sources/Optimizer/FunctionPasses/AutodiffClosureSpecialization.swift

Lines changed: 689 additions & 0 deletions
Large diffs are not rendered by default.

SwiftCompilerSources/Sources/Optimizer/FunctionPasses/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,5 @@ swift_compiler_sources(Optimizer
3030
SimplificationPasses.swift
3131
StackPromotion.swift
3232
StripObjectHeaders.swift
33+
AutodiffClosureSpecialization.swift
3334
)

SwiftCompilerSources/Sources/Optimizer/PassManager/PassRegistration.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ private func registerSwiftPasses() {
9393
registerPass(lifetimeDependenceDiagnosticsPass, { lifetimeDependenceDiagnosticsPass.run($0) })
9494
registerPass(lifetimeDependenceInsertionPass, { lifetimeDependenceInsertionPass.run($0) })
9595
registerPass(lifetimeDependenceScopeFixupPass, { lifetimeDependenceScopeFixupPass.run($0) })
96+
registerPass(autodiffClosureSpecialization, { autodiffClosureSpecialization.run($0) })
97+
9698
// Instruction passes
9799
registerForSILCombine(BeginCOWMutationInst.self, { run(BeginCOWMutationInst.self, $0) })
98100
registerForSILCombine(GlobalValueInst.self, { run(GlobalValueInst.self, $0) })

SwiftCompilerSources/Sources/Optimizer/Utilities/Test.swift

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,8 @@ public func registerOptimizerTests() {
164164
lifetimeDependenceUseTest,
165165
linearLivenessTest,
166166
parseTestSpecificationTest,
167-
variableIntroducerTest
167+
variableIntroducerTest,
168+
gatherCallSitesTest
168169
)
169170

170171
// Finally register the thunk they all call through.

include/swift/SILOptimizer/PassManager/Passes.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,8 @@ PASS(CapturePropagation, "capture-prop",
146146
"Captured Constant Propagation")
147147
PASS(ClosureSpecializer, "closure-specialize",
148148
"Closure Specialization on Constant Function Arguments")
149+
SWIFT_FUNCTION_PASS(AutodiffClosureSpecialization, "autodiff-closure-specialize",
150+
"Autodiff specific closure-specialization pass")
149151
PASS(ClosureLifetimeFixup, "closure-lifetime-fixup",
150152
"Closure Lifetime Fixup")
151153
PASS(CodeSinking, "code-sinking",
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
// RUN: %target-sil-opt -test-runner %s -o /dev/null 2>&1 | %FileCheck %s
2+
3+
// REQUIRES: swift_in_compiler
4+
5+
sil_stage canonical
6+
7+
import Builtin
8+
import Swift
9+
import SwiftShims
10+
11+
import _Differentiation
12+
13+
// ===================== Gathering callsites and corresponding closures ===================== //
14+
15+
//////////////////////////////
16+
// Single closure call site //
17+
//////////////////////////////
18+
sil @$vjpMultiply : $@convention(thin) (Float, Float, Float) -> (Float, Float)
19+
20+
sil private @$pullback_f : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float {
21+
bb0(%0 : $Float, %1 : $@callee_guaranteed (Float) -> (Float, Float)):
22+
%2 = apply %1(%0) : $@callee_guaranteed (Float) -> (Float, Float) // users: %5, %4
23+
strong_release %1 : $@callee_guaranteed (Float) -> (Float, Float) // id: %3
24+
%4 = tuple_extract %2 : $(Float, Float), 0 // user: %7
25+
%5 = tuple_extract %2 : $(Float, Float), 1 // user: %6
26+
%6 = struct_extract %5 : $Float, #Float._value // user: %8
27+
%7 = struct_extract %4 : $Float, #Float._value // user: %8
28+
%8 = builtin "fadd_FPIEEE32"(%6 : $Builtin.FPIEEE32, %7 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %9
29+
%9 = struct $Float (%8 : $Builtin.FPIEEE32) // users: %11, %10
30+
debug_value %9 : $Float, let, name "x", argno 1 // id: %10
31+
return %9 : $Float // id: %11
32+
}
33+
34+
// reverse-mode derivative of f(_:)
35+
sil hidden @$s4test1fyS2fFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
36+
bb0(%0 : $Float):
37+
specify_test "closure_specialize_gather_call_sites"
38+
// CHECK-LABEL: Specializing closures in function: $s4test1fyS2fFTJrSpSr
39+
// CHECK: PartialApply call site: %8 = partial_apply [callee_guaranteed] %7(%6) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float // user: %9
40+
// CHECK: Passed in closures:
41+
// CHECK: 1. %6 = partial_apply [callee_guaranteed] %5(%0, %0) : $@convention(thin) (Float, Float, Float) -> (Float, Float) // user: %8
42+
43+
debug_value %0 : $Float, let, name "x", argno 1 // id: %1
44+
%2 = struct_extract %0 : $Float, #Float._value // users: %3, %3
45+
%3 = builtin "fmul_FPIEEE32"(%2 : $Builtin.FPIEEE32, %2 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %4
46+
%4 = struct $Float (%3 : $Builtin.FPIEEE32) // user: %9
47+
// function_ref closure #1 in static Float._vjpMultiply(lhs:rhs:)
48+
%5 = function_ref @$vjpMultiply : $@convention(thin) (Float, Float, Float) -> (Float, Float) // user: %6
49+
%6 = partial_apply [callee_guaranteed] %5(%0, %0) : $@convention(thin) (Float, Float, Float) -> (Float, Float) // user: %8
50+
// function_ref pullback of f(_:)
51+
%7 = function_ref @$pullback_f : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float // user: %8
52+
%8 = partial_apply [callee_guaranteed] %7(%6) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float // user: %9
53+
%9 = tuple (%4 : $Float, %8 : $@callee_guaranteed (Float) -> Float) // user: %10
54+
return %9 : $(Float, @callee_guaranteed (Float) -> Float) // id: %10
55+
}
56+
57+
///////////////////////////////
58+
// Multiple closure callsite //
59+
///////////////////////////////
60+
sil @$_vjpSin : $@convention(thin) (Float, Float) -> Float // user: %6
61+
sil @$_vjpCos : $@convention(thin) (Float, Float) -> Float // user: %10
62+
sil @$_vjpMultiply : $@convention(thin) (Float, Float, Float) -> (Float, Float)
63+
64+
// pullback of g(_:)
65+
sil private @$pullback_g : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float {
66+
bb0(%0 : $Float, %1 : $@callee_guaranteed (Float) -> Float, %2 : $@callee_guaranteed (Float) -> Float, %3 : $@callee_guaranteed (Float) -> (Float, Float)):
67+
%4 = apply %3(%0) : $@callee_guaranteed (Float) -> (Float, Float) // users: %7, %6
68+
strong_release %3 : $@callee_guaranteed (Float) -> (Float, Float) // id: %5
69+
%6 = tuple_extract %4 : $(Float, Float), 0 // user: %10
70+
%7 = tuple_extract %4 : $(Float, Float), 1 // user: %8
71+
%8 = apply %2(%7) : $@callee_guaranteed (Float) -> Float // user: %12
72+
strong_release %2 : $@callee_guaranteed (Float) -> Float // id: %9
73+
%10 = apply %1(%6) : $@callee_guaranteed (Float) -> Float // user: %13
74+
strong_release %1 : $@callee_guaranteed (Float) -> Float // id: %11
75+
%12 = struct_extract %8 : $Float, #Float._value // user: %14
76+
%13 = struct_extract %10 : $Float, #Float._value // user: %14
77+
%14 = builtin "fadd_FPIEEE32"(%13 : $Builtin.FPIEEE32, %12 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %15
78+
%15 = struct $Float (%14 : $Builtin.FPIEEE32) // users: %17, %16
79+
debug_value %15 : $Float, let, name "x", argno 1 // id: %16
80+
return %15 : $Float // id: %17
81+
}
82+
83+
// reverse-mode derivative of g(_:)
84+
sil hidden @$s4test1gyS2fFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
85+
bb0(%0 : $Float):
86+
specify_test "closure_specialize_gather_call_sites"
87+
// CHECK-LABEL: Specializing closures in function: $s4test1gyS2fFTJrSpSr
88+
// CHECK: PartialApply call site: %16 = partial_apply [callee_guaranteed] %15(%6, %10, %14) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float // user: %17
89+
// CHECK: Passed in closures:
90+
// CHECK: 1. %6 = partial_apply [callee_guaranteed] %5(%0) : $@convention(thin) (Float, Float) -> Float // user: %16
91+
// CHECK: 2. %10 = partial_apply [callee_guaranteed] %9(%0) : $@convention(thin) (Float, Float) -> Float // user: %16
92+
// CHECK: 3. %14 = partial_apply [callee_guaranteed] %13(%8, %4) : $@convention(thin) (Float, Float, Float) -> (Float, Float) // user: %16
93+
94+
debug_value %0 : $Float, let, name "x", argno 1 // id: %1
95+
%2 = struct_extract %0 : $Float, #Float._value // users: %7, %3
96+
%3 = builtin "int_sin_FPIEEE32"(%2 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // users: %11, %4
97+
%4 = struct $Float (%3 : $Builtin.FPIEEE32) // user: %14
98+
// function_ref closure #1 in _vjpSin(_:)
99+
%5 = function_ref @$_vjpSin : $@convention(thin) (Float, Float) -> Float // user: %6
100+
%6 = partial_apply [callee_guaranteed] %5(%0) : $@convention(thin) (Float, Float) -> Float // user: %16
101+
%7 = builtin "int_cos_FPIEEE32"(%2 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // users: %11, %8
102+
%8 = struct $Float (%7 : $Builtin.FPIEEE32) // user: %14
103+
// function_ref closure #1 in _vjpCos(_:)
104+
%9 = function_ref @$_vjpCos : $@convention(thin) (Float, Float) -> Float // user: %10
105+
%10 = partial_apply [callee_guaranteed] %9(%0) : $@convention(thin) (Float, Float) -> Float // user: %16
106+
%11 = builtin "fmul_FPIEEE32"(%3 : $Builtin.FPIEEE32, %7 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %12
107+
%12 = struct $Float (%11 : $Builtin.FPIEEE32) // user: %17
108+
// function_ref closure #1 in static Float._vjpMultiply(lhs:rhs:)
109+
%13 = function_ref @$_vjpMultiply : $@convention(thin) (Float, Float, Float) -> (Float, Float) // user: %14
110+
%14 = partial_apply [callee_guaranteed] %13(%8, %4) : $@convention(thin) (Float, Float, Float) -> (Float, Float) // user: %16
111+
// function_ref pullback of g(_:)
112+
%15 = function_ref @$pullback_g : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float // user: %16
113+
%16 = partial_apply [callee_guaranteed] %15(%6, %10, %14) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float // user: %17
114+
%17 = tuple (%12 : $Float, %16 : $@callee_guaranteed (Float) -> Float) // user: %18
115+
return %17 : $(Float, @callee_guaranteed (Float) -> Float) // id: %18
116+
}

0 commit comments

Comments
 (0)