Skip to content

Commit 5528cf1

Browse files
authored
[AutoDiff] Run AutoDiff closure spec pass for all VJPs (#81548)
Previously, AutoDiff closure specialization pass was triggered only on VJPs containing single basic block. However, the pass logic allows running on arbitrary VJPs. This PR enables the pass for all VJPs unconditionally. So, if the pullback corresponding to multiple-BB VJP accepts some closures directly as arguments, these closures might become specialized by the pass. Closures passed via payload of branch tracing enum are not specialized - this is subject for future changes. The PR contains several commits. 1. The thing named "call site" in the code is partial_apply of pullback corresponding to the VJP. This might appear only once, so we drop support for multiple "call sites". 2. Enhance existing SILOptimizer tests for the pass. 3. Add validation-tests for single basic block case. 4. The change itself - delete check against single basic block. 5. Add validation-tests for multiple basic block case. 6. Add SILOptimizer tests for multiple basic block case.
1 parent 36b5090 commit 5528cf1

File tree

11 files changed

+2219
-659
lines changed

11 files changed

+2219
-659
lines changed

SwiftCompilerSources/Sources/Optimizer/FunctionPasses/ClosureSpecialization.swift

Lines changed: 132 additions & 149 deletions
Large diffs are not rendered by default.

SwiftCompilerSources/Sources/Optimizer/Utilities/Test.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ public func registerOptimizerTests() {
160160
enclosingValuesTest,
161161
forwardingDefUseTest,
162162
forwardingUseDefTest,
163-
gatherCallSitesTest,
163+
getPullbackClosureInfoTest,
164164
interiorLivenessTest,
165165
lifetimeDependenceRootTest,
166166
lifetimeDependenceScopeTest,

test/AutoDiff/SILOptimizer/closure_specialization.sil

Lines changed: 0 additions & 509 deletions
This file was deleted.

test/AutoDiff/SILOptimizer/closure_specialization/multi_bb_bte.sil

Lines changed: 689 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
/// Multi basic block VJP, pullback not accepting branch tracing enum argument.
2+
3+
// RUN: %target-sil-opt -sil-print-types -test-runner %s -o /dev/null 2>&1 | %FileCheck %s --check-prefixes=TRUNNER,CHECK
4+
// RUN: %target-sil-opt -sil-print-types -autodiff-closure-specialization -sil-combine %s -o - | %FileCheck %s --check-prefixes=COMBINE,CHECK
5+
6+
// REQUIRES: swift_in_compiler
7+
8+
sil_stage canonical
9+
10+
import Builtin
11+
import Swift
12+
import SwiftShims
13+
14+
import _Differentiation
15+
16+
/// This SIL corresponds to the following Swift:
17+
///
18+
/// struct Class: Differentiable {
19+
/// var stored: Float
20+
/// var optional: Float?
21+
///
22+
/// init(stored: Float, optional: Float?) {
23+
/// self.stored = stored
24+
/// self.optional = optional
25+
/// }
26+
///
27+
/// @differentiable(reverse)
28+
/// func method() -> Float {
29+
/// let c: Class
30+
/// do {
31+
/// let tmp = Class(stored: 1 * stored, optional: optional)
32+
/// let tuple = (tmp, tmp)
33+
/// c = tuple.0
34+
/// }
35+
/// if let x = c.optional {
36+
/// return x * c.stored
37+
/// }
38+
/// return 1 * c.stored
39+
/// }
40+
/// }
41+
///
42+
/// @differentiable(reverse)
43+
/// func methodWrapper(_ x: Class) -> Float {
44+
/// x.method()
45+
/// }
46+
47+
struct Class : Differentiable {
48+
@_hasStorage var stored: Float { get set }
49+
@_hasStorage @_hasInitialValue var optional: Float? { get set }
50+
init(stored: Float, optional: Float?)
51+
@differentiable(reverse, wrt: self)
52+
func method() -> Float
53+
struct TangentVector : AdditiveArithmetic, Differentiable {
54+
@_hasStorage var stored: Float { get set }
55+
@_hasStorage var optional: Optional<Float>.TangentVector { get set }
56+
static func + (lhs: Class.TangentVector, rhs: Class.TangentVector) -> Class.TangentVector
57+
static func - (lhs: Class.TangentVector, rhs: Class.TangentVector) -> Class.TangentVector
58+
typealias TangentVector = Class.TangentVector
59+
@_implements(Equatable, ==(_:_:)) static func __derived_struct_equals(_ a: Class.TangentVector, _ b: Class.TangentVector) -> Bool
60+
init(stored: Float, optional: Optional<Float>.TangentVector)
61+
static var zero: Class.TangentVector { get }
62+
}
63+
mutating func move(by offset: Class.TangentVector)
64+
}
65+
66+
enum _AD__$s4test5ClassV6methodSfyF_bb1__Pred__src_0_wrt_0 {
67+
case bb0(((Float) -> Float, (Class.TangentVector) -> (Float, Optional<Float>.TangentVector)))
68+
}
69+
70+
enum _AD__$s4test5ClassV6methodSfyF_bb2__Pred__src_0_wrt_0 {
71+
case bb0(((Float) -> Float, (Class.TangentVector) -> (Float, Optional<Float>.TangentVector)))
72+
}
73+
74+
enum _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0 {
75+
case bb2((predecessor: _AD__$s4test5ClassV6methodSfyF_bb2__Pred__src_0_wrt_0, (Float) -> Float))
76+
case bb1((predecessor: _AD__$s4test5ClassV6methodSfyF_bb1__Pred__src_0_wrt_0, (Float) -> (Float, Float)))
77+
}
78+
79+
sil @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float, Float, Float) -> (Float, Float)
80+
sil [transparent] [thunk] @$sS3fIegydd_TJSpSSUpSrUSUP : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float
81+
sil @$s4test5ClassV6stored8optionalACSf_SfSgtcfCTJpSSUpSr : $@convention(thin) (Class.TangentVector) -> (Float, Optional<Float>.TangentVector)
82+
sil @$s4test5ClassV6methodSfyFTJpSpSr : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector
83+
84+
// pullback of methodWrapper(_:)
85+
sil private [signature_optimized_thunk] [always_inline] @$s4test13methodWrapperySfAA5ClassVFTJpSpSr : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Class.TangentVector) -> Class.TangentVector {
86+
bb0(%0 : $Float, %1 : $@callee_guaranteed (Float) -> Class.TangentVector):
87+
%2 = apply %1(%0) : $@callee_guaranteed (Float) -> Class.TangentVector
88+
strong_release %1
89+
return %2
90+
} // end sil function '$s4test13methodWrapperySfAA5ClassVFTJpSpSr'
91+
92+
// reverse-mode derivative of methodWrapper(_:)
93+
sil hidden @$s4test13methodWrapperySfAA5ClassVFTJrSpSr : $@convention(thin) (Class) -> (Float, @owned @callee_guaranteed (Float) -> Class.TangentVector) {
94+
bb0(%0 : $Class):
95+
//=========== Test callsite and closure gathering logic ===========//
96+
specify_test "autodiff_closure_specialize_get_pullback_closure_info"
97+
// TRUNNER-LABEL: Specializing closures in function: $s4test13methodWrapperySfAA5ClassVFTJrSpSr
98+
// TRUNNER: PartialApply of pullback: %[[#]] = partial_apply [callee_guaranteed] %[[#]](%[[#A36:]]) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Class.TangentVector) -> Class.TangentVector
99+
// TRUNNER-NEXT: Passed in closures:
100+
// TRUNNER-NEXT: 1. %[[#A36]] = partial_apply [callee_guaranteed] %[[#]](%[[#]]) : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector
101+
// TRUNNER-EMPTY:
102+
103+
//=========== Test specialized function signature and body ===========//
104+
specify_test "autodiff_closure_specialize_specialized_function_signature_and_body"
105+
// TRUNNER-LABEL: Generated specialized function: $s4test13methodWrapperySfAA5ClassVFTJpSpSr08$s4test5D19V6methodSfyFTJpSpSr4main05_AD__edfG24F_bb3__Pred__src_0_wrt_0OTf1nc_n
106+
// CHECK: sil private [signature_optimized_thunk] [always_inline] @$s4test13methodWrapperySfAA5ClassVFTJpSpSr08$s4test5D19V6methodSfyFTJpSpSr4main05_AD__edfG24F_bb3__Pred__src_0_wrt_0OTf1nc_n : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector {
107+
// CHECK: bb0(%0 : $Float, %1 : $_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0):
108+
// CHECK: %[[#B2:]] = function_ref @$s4test5ClassV6methodSfyFTJpSpSr : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector
109+
// TRUNNER: %[[#B3:]] = partial_apply [callee_guaranteed] %[[#B2]](%1) : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector
110+
// TRUNNER: %[[#B4:]] = apply %[[#B3]](%0) : $@callee_guaranteed (Float) -> Class.TangentVector
111+
// COMBINE-NOT: partial_apply
112+
// COMBINE: %[[#B4:]] = apply %[[#B2]](%0, %1) : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector
113+
// TRUNNER: strong_release %[[#B3]] : $@callee_guaranteed (Float) -> Class.TangentVector
114+
// CHECK: return %[[#B4]]
115+
116+
//=========== Test rewritten body ===========//
117+
specify_test "autodiff_closure_specialize_rewritten_caller_body"
118+
// TRUNNER-LABEL: Rewritten caller body for: $s4test13methodWrapperySfAA5ClassVFTJrSpSr:
119+
// CHECK: sil hidden @$s4test13methodWrapperySfAA5ClassVFTJrSpSr : $@convention(thin) (Class) -> (Float, @owned @callee_guaranteed (Float) -> Class.TangentVector) {
120+
// CHECK: bb3(%[[#C33:]] : $Float, %[[#C34:]] : $_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0):
121+
// TRUNNER: %[[#C35:]] = function_ref @$s4test5ClassV6methodSfyFTJpSpSr : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector
122+
// TRUNNER: %[[#C37:]] = partial_apply [callee_guaranteed] %[[#C35]](%[[#C34]]) : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector
123+
// TRUNNER: %[[#C38:]] = function_ref @$s4test13methodWrapperySfAA5ClassVFTJpSpSr : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Class.TangentVector) -> Class.TangentVector
124+
// COMBINE-NOT: function_ref @$s4test5ClassV6methodSfyFTJpSpSr
125+
// COMBINE-NOT: partial_apply
126+
// COMBINE-NOT: function_ref @$s4test13methodWrapperySfAA5ClassVFTJpSpSr
127+
// CHECK: %[[#C39:]] = function_ref @$s4test13methodWrapperySfAA5ClassVFTJpSpSr08$s4test5D19V6methodSfyFTJpSpSr4main05_AD__edfG24F_bb3__Pred__src_0_wrt_0OTf1nc_n : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector
128+
// CHECK: %[[#C40:]] = partial_apply [callee_guaranteed] %[[#C39]](%[[#C34]]) : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector
129+
// TRUNNER: release_value %[[#C37]] : $@callee_guaranteed (Float) -> Class.TangentVector
130+
// CHECK: %[[#C42:]] = tuple (%[[#C33]] : $Float, %[[#C40]] : $@callee_guaranteed (Float) -> Class.TangentVector)
131+
// CHECK: return %[[#C42]]
132+
133+
%3 = float_literal $Builtin.FPIEEE32, 0x3F800000 // 1
134+
%4 = struct $Float (%3)
135+
%5 = struct_extract %0, #Class.stored
136+
%6 = struct_extract %5, #Float._value
137+
%7 = builtin "fmul_FPIEEE32"(%3, %6) : $Builtin.FPIEEE32
138+
%8 = struct $Float (%7)
139+
// function_ref closure #1 in static Float._vjpMultiply(lhs:rhs:)
140+
%9 = function_ref @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float, Float, Float) -> (Float, Float)
141+
%10 = partial_apply [callee_guaranteed] %9(%5, %4) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
142+
// function_ref autodiff subset parameters thunk for pullback from @escaping @callee_guaranteed (@unowned Float) -> (@unowned Float, @unowned Float)
143+
%11 = function_ref @$sS3fIegydd_TJSpSSUpSrUSUP : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float
144+
%12 = partial_apply [callee_guaranteed] %11(%10) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float
145+
%13 = struct_extract %0, #Class.optional
146+
// function_ref pullback of Class.init(stored:optional:)
147+
%26 = function_ref @$s4test5ClassV6stored8optionalACSf_SfSgtcfCTJpSSUpSr : $@convention(thin) (Class.TangentVector) -> (Float, Optional<Float>.TangentVector)
148+
%27 = thin_to_thick_function %26 to $@callee_guaranteed (Class.TangentVector) -> (Float, Optional<Float>.TangentVector)
149+
%28 = tuple (%12, %27)
150+
switch_enum %13, case #Optional.some!enumelt: bb1, case #Optional.none!enumelt: bb2
151+
152+
bb1(%30 : $Float):
153+
%31 = enum $_AD__$s4test5ClassV6methodSfyF_bb1__Pred__src_0_wrt_0, #_AD__$s4test5ClassV6methodSfyF_bb1__Pred__src_0_wrt_0.bb0!enumelt, %28
154+
%33 = struct_extract %30, #Float._value
155+
%34 = builtin "fmul_FPIEEE32"(%33, %7) : $Builtin.FPIEEE32
156+
%35 = struct $Float (%34)
157+
%36 = partial_apply [callee_guaranteed] %9(%8, %30) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
158+
%37 = tuple $(predecessor: _AD__$s4test5ClassV6methodSfyF_bb1__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> (Float, Float)) (%31, %36)
159+
%38 = enum $_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0, #_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0.bb1!enumelt, %37
160+
br bb3(%35, %38)
161+
162+
bb2:
163+
%40 = enum $_AD__$s4test5ClassV6methodSfyF_bb2__Pred__src_0_wrt_0, #_AD__$s4test5ClassV6methodSfyF_bb2__Pred__src_0_wrt_0.bb0!enumelt, %28
164+
%41 = builtin "fmul_FPIEEE32"(%3, %7) : $Builtin.FPIEEE32
165+
%42 = struct $Float (%41)
166+
%43 = partial_apply [callee_guaranteed] %9(%8, %4) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
167+
%44 = partial_apply [callee_guaranteed] %11(%43) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float
168+
%45 = tuple $(predecessor: _AD__$s4test5ClassV6methodSfyF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float) (%40, %44)
169+
%46 = enum $_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0, #_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0.bb2!enumelt, %45
170+
br bb3(%42, %46)
171+
172+
bb3(%48 : $Float, %49 : $_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0):
173+
// function_ref pullback of Class.method()
174+
%50 = function_ref @$s4test5ClassV6methodSfyFTJpSpSr : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector
175+
%51 = partial_apply [callee_guaranteed] %50(%49) : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector
176+
// function_ref pullback of methodWrapper(_:)
177+
%52 = function_ref @$s4test13methodWrapperySfAA5ClassVFTJpSpSr : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Class.TangentVector) -> Class.TangentVector
178+
%53 = partial_apply [callee_guaranteed] %52(%51) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Class.TangentVector) -> Class.TangentVector
179+
%54 = tuple (%48, %53)
180+
return %54
181+
} // end sil function '$s4test13methodWrapperySfAA5ClassVFTJrSpSr'

0 commit comments

Comments
 (0)