|
| 1 | +/// Multi basic block VJP, pullback not accepting branch tracing enum argument. |
| 2 | + |
| 3 | +// RUN: %target-sil-opt -sil-print-types -test-runner %s -o /dev/null 2>&1 | %FileCheck %s --check-prefixes=TRUNNER,CHECK |
| 4 | +// RUN: %target-sil-opt -sil-print-types -autodiff-closure-specialization -sil-combine %s -o - | %FileCheck %s --check-prefixes=COMBINE,CHECK |
| 5 | + |
| 6 | +// REQUIRES: swift_in_compiler |
| 7 | + |
| 8 | +sil_stage canonical |
| 9 | + |
| 10 | +import Builtin |
| 11 | +import Swift |
| 12 | +import SwiftShims |
| 13 | + |
| 14 | +import _Differentiation |
| 15 | + |
| 16 | +/// This SIL corresponds to the following Swift: |
| 17 | +/// |
| 18 | +/// struct Class: Differentiable { |
| 19 | +/// var stored: Float |
| 20 | +/// var optional: Float? |
| 21 | +/// |
| 22 | +/// init(stored: Float, optional: Float?) { |
| 23 | +/// self.stored = stored |
| 24 | +/// self.optional = optional |
| 25 | +/// } |
| 26 | +/// |
| 27 | +/// @differentiable(reverse) |
| 28 | +/// func method() -> Float { |
| 29 | +/// let c: Class |
| 30 | +/// do { |
| 31 | +/// let tmp = Class(stored: 1 * stored, optional: optional) |
| 32 | +/// let tuple = (tmp, tmp) |
| 33 | +/// c = tuple.0 |
| 34 | +/// } |
| 35 | +/// if let x = c.optional { |
| 36 | +/// return x * c.stored |
| 37 | +/// } |
| 38 | +/// return 1 * c.stored |
| 39 | +/// } |
| 40 | +/// } |
| 41 | +/// |
| 42 | +/// @differentiable(reverse) |
| 43 | +/// func methodWrapper(_ x: Class) -> Float { |
| 44 | +/// x.method() |
| 45 | +/// } |
| 46 | + |
| 47 | +struct Class : Differentiable { |
| 48 | + @_hasStorage var stored: Float { get set } |
| 49 | + @_hasStorage @_hasInitialValue var optional: Float? { get set } |
| 50 | + init(stored: Float, optional: Float?) |
| 51 | + @differentiable(reverse, wrt: self) |
| 52 | + func method() -> Float |
| 53 | + struct TangentVector : AdditiveArithmetic, Differentiable { |
| 54 | + @_hasStorage var stored: Float { get set } |
| 55 | + @_hasStorage var optional: Optional<Float>.TangentVector { get set } |
| 56 | + static func + (lhs: Class.TangentVector, rhs: Class.TangentVector) -> Class.TangentVector |
| 57 | + static func - (lhs: Class.TangentVector, rhs: Class.TangentVector) -> Class.TangentVector |
| 58 | + typealias TangentVector = Class.TangentVector |
| 59 | + @_implements(Equatable, ==(_:_:)) static func __derived_struct_equals(_ a: Class.TangentVector, _ b: Class.TangentVector) -> Bool |
| 60 | + init(stored: Float, optional: Optional<Float>.TangentVector) |
| 61 | + static var zero: Class.TangentVector { get } |
| 62 | + } |
| 63 | + mutating func move(by offset: Class.TangentVector) |
| 64 | +} |
| 65 | + |
| 66 | +enum _AD__$s4test5ClassV6methodSfyF_bb1__Pred__src_0_wrt_0 { |
| 67 | + case bb0(((Float) -> Float, (Class.TangentVector) -> (Float, Optional<Float>.TangentVector))) |
| 68 | +} |
| 69 | + |
| 70 | +enum _AD__$s4test5ClassV6methodSfyF_bb2__Pred__src_0_wrt_0 { |
| 71 | + case bb0(((Float) -> Float, (Class.TangentVector) -> (Float, Optional<Float>.TangentVector))) |
| 72 | +} |
| 73 | + |
| 74 | +enum _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0 { |
| 75 | + case bb2((predecessor: _AD__$s4test5ClassV6methodSfyF_bb2__Pred__src_0_wrt_0, (Float) -> Float)) |
| 76 | + case bb1((predecessor: _AD__$s4test5ClassV6methodSfyF_bb1__Pred__src_0_wrt_0, (Float) -> (Float, Float))) |
| 77 | +} |
| 78 | + |
| 79 | +sil @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float, Float, Float) -> (Float, Float) |
| 80 | +sil [transparent] [thunk] @$sS3fIegydd_TJSpSSUpSrUSUP : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float |
| 81 | +sil @$s4test5ClassV6stored8optionalACSf_SfSgtcfCTJpSSUpSr : $@convention(thin) (Class.TangentVector) -> (Float, Optional<Float>.TangentVector) |
| 82 | +sil @$s4test5ClassV6methodSfyFTJpSpSr : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector |
| 83 | + |
| 84 | +// pullback of methodWrapper(_:) |
| 85 | +sil private [signature_optimized_thunk] [always_inline] @$s4test13methodWrapperySfAA5ClassVFTJpSpSr : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Class.TangentVector) -> Class.TangentVector { |
| 86 | +bb0(%0 : $Float, %1 : $@callee_guaranteed (Float) -> Class.TangentVector): |
| 87 | + %2 = apply %1(%0) : $@callee_guaranteed (Float) -> Class.TangentVector |
| 88 | + strong_release %1 |
| 89 | + return %2 |
| 90 | +} // end sil function '$s4test13methodWrapperySfAA5ClassVFTJpSpSr' |
| 91 | + |
| 92 | +// reverse-mode derivative of methodWrapper(_:) |
| 93 | +sil hidden @$s4test13methodWrapperySfAA5ClassVFTJrSpSr : $@convention(thin) (Class) -> (Float, @owned @callee_guaranteed (Float) -> Class.TangentVector) { |
| 94 | +bb0(%0 : $Class): |
| 95 | + //=========== Test callsite and closure gathering logic ===========// |
| 96 | + specify_test "autodiff_closure_specialize_get_pullback_closure_info" |
| 97 | + // TRUNNER-LABEL: Specializing closures in function: $s4test13methodWrapperySfAA5ClassVFTJrSpSr |
| 98 | + // TRUNNER: PartialApply of pullback: %[[#]] = partial_apply [callee_guaranteed] %[[#]](%[[#A36:]]) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Class.TangentVector) -> Class.TangentVector |
| 99 | + // TRUNNER-NEXT: Passed in closures: |
| 100 | + // TRUNNER-NEXT: 1. %[[#A36]] = partial_apply [callee_guaranteed] %[[#]](%[[#]]) : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector |
| 101 | + // TRUNNER-EMPTY: |
| 102 | + |
| 103 | + //=========== Test specialized function signature and body ===========// |
| 104 | + specify_test "autodiff_closure_specialize_specialized_function_signature_and_body" |
| 105 | + // TRUNNER-LABEL: Generated specialized function: $s4test13methodWrapperySfAA5ClassVFTJpSpSr08$s4test5D19V6methodSfyFTJpSpSr4main05_AD__edfG24F_bb3__Pred__src_0_wrt_0OTf1nc_n |
| 106 | + // CHECK: sil private [signature_optimized_thunk] [always_inline] @$s4test13methodWrapperySfAA5ClassVFTJpSpSr08$s4test5D19V6methodSfyFTJpSpSr4main05_AD__edfG24F_bb3__Pred__src_0_wrt_0OTf1nc_n : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector { |
| 107 | + // CHECK: bb0(%0 : $Float, %1 : $_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0): |
| 108 | + // CHECK: %[[#B2:]] = function_ref @$s4test5ClassV6methodSfyFTJpSpSr : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector |
| 109 | + // TRUNNER: %[[#B3:]] = partial_apply [callee_guaranteed] %[[#B2]](%1) : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector |
| 110 | + // TRUNNER: %[[#B4:]] = apply %[[#B3]](%0) : $@callee_guaranteed (Float) -> Class.TangentVector |
| 111 | + // COMBINE-NOT: partial_apply |
| 112 | + // COMBINE: %[[#B4:]] = apply %[[#B2]](%0, %1) : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector |
| 113 | + // TRUNNER: strong_release %[[#B3]] : $@callee_guaranteed (Float) -> Class.TangentVector |
| 114 | + // CHECK: return %[[#B4]] |
| 115 | + |
| 116 | + //=========== Test rewritten body ===========// |
| 117 | + specify_test "autodiff_closure_specialize_rewritten_caller_body" |
| 118 | + // TRUNNER-LABEL: Rewritten caller body for: $s4test13methodWrapperySfAA5ClassVFTJrSpSr: |
| 119 | + // CHECK: sil hidden @$s4test13methodWrapperySfAA5ClassVFTJrSpSr : $@convention(thin) (Class) -> (Float, @owned @callee_guaranteed (Float) -> Class.TangentVector) { |
| 120 | + // CHECK: bb3(%[[#C33:]] : $Float, %[[#C34:]] : $_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0): |
| 121 | + // TRUNNER: %[[#C35:]] = function_ref @$s4test5ClassV6methodSfyFTJpSpSr : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector |
| 122 | + // TRUNNER: %[[#C37:]] = partial_apply [callee_guaranteed] %[[#C35]](%[[#C34]]) : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector |
| 123 | + // TRUNNER: %[[#C38:]] = function_ref @$s4test13methodWrapperySfAA5ClassVFTJpSpSr : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Class.TangentVector) -> Class.TangentVector |
| 124 | + // COMBINE-NOT: function_ref @$s4test5ClassV6methodSfyFTJpSpSr |
| 125 | + // COMBINE-NOT: partial_apply |
| 126 | + // COMBINE-NOT: function_ref @$s4test13methodWrapperySfAA5ClassVFTJpSpSr |
| 127 | + // CHECK: %[[#C39:]] = function_ref @$s4test13methodWrapperySfAA5ClassVFTJpSpSr08$s4test5D19V6methodSfyFTJpSpSr4main05_AD__edfG24F_bb3__Pred__src_0_wrt_0OTf1nc_n : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector |
| 128 | + // CHECK: %[[#C40:]] = partial_apply [callee_guaranteed] %[[#C39]](%[[#C34]]) : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector |
| 129 | + // TRUNNER: release_value %[[#C37]] : $@callee_guaranteed (Float) -> Class.TangentVector |
| 130 | + // CHECK: %[[#C42:]] = tuple (%[[#C33]] : $Float, %[[#C40]] : $@callee_guaranteed (Float) -> Class.TangentVector) |
| 131 | + // CHECK: return %[[#C42]] |
| 132 | + |
| 133 | + %3 = float_literal $Builtin.FPIEEE32, 0x3F800000 // 1 |
| 134 | + %4 = struct $Float (%3) |
| 135 | + %5 = struct_extract %0, #Class.stored |
| 136 | + %6 = struct_extract %5, #Float._value |
| 137 | + %7 = builtin "fmul_FPIEEE32"(%3, %6) : $Builtin.FPIEEE32 |
| 138 | + %8 = struct $Float (%7) |
| 139 | + // function_ref closure #1 in static Float._vjpMultiply(lhs:rhs:) |
| 140 | + %9 = function_ref @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float, Float, Float) -> (Float, Float) |
| 141 | + %10 = partial_apply [callee_guaranteed] %9(%5, %4) : $@convention(thin) (Float, Float, Float) -> (Float, Float) |
| 142 | + // function_ref autodiff subset parameters thunk for pullback from @escaping @callee_guaranteed (@unowned Float) -> (@unowned Float, @unowned Float) |
| 143 | + %11 = function_ref @$sS3fIegydd_TJSpSSUpSrUSUP : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float |
| 144 | + %12 = partial_apply [callee_guaranteed] %11(%10) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float |
| 145 | + %13 = struct_extract %0, #Class.optional |
| 146 | + // function_ref pullback of Class.init(stored:optional:) |
| 147 | + %26 = function_ref @$s4test5ClassV6stored8optionalACSf_SfSgtcfCTJpSSUpSr : $@convention(thin) (Class.TangentVector) -> (Float, Optional<Float>.TangentVector) |
| 148 | + %27 = thin_to_thick_function %26 to $@callee_guaranteed (Class.TangentVector) -> (Float, Optional<Float>.TangentVector) |
| 149 | + %28 = tuple (%12, %27) |
| 150 | + switch_enum %13, case #Optional.some!enumelt: bb1, case #Optional.none!enumelt: bb2 |
| 151 | + |
| 152 | +bb1(%30 : $Float): |
| 153 | + %31 = enum $_AD__$s4test5ClassV6methodSfyF_bb1__Pred__src_0_wrt_0, #_AD__$s4test5ClassV6methodSfyF_bb1__Pred__src_0_wrt_0.bb0!enumelt, %28 |
| 154 | + %33 = struct_extract %30, #Float._value |
| 155 | + %34 = builtin "fmul_FPIEEE32"(%33, %7) : $Builtin.FPIEEE32 |
| 156 | + %35 = struct $Float (%34) |
| 157 | + %36 = partial_apply [callee_guaranteed] %9(%8, %30) : $@convention(thin) (Float, Float, Float) -> (Float, Float) |
| 158 | + %37 = tuple $(predecessor: _AD__$s4test5ClassV6methodSfyF_bb1__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> (Float, Float)) (%31, %36) |
| 159 | + %38 = enum $_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0, #_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0.bb1!enumelt, %37 |
| 160 | + br bb3(%35, %38) |
| 161 | + |
| 162 | +bb2: |
| 163 | + %40 = enum $_AD__$s4test5ClassV6methodSfyF_bb2__Pred__src_0_wrt_0, #_AD__$s4test5ClassV6methodSfyF_bb2__Pred__src_0_wrt_0.bb0!enumelt, %28 |
| 164 | + %41 = builtin "fmul_FPIEEE32"(%3, %7) : $Builtin.FPIEEE32 |
| 165 | + %42 = struct $Float (%41) |
| 166 | + %43 = partial_apply [callee_guaranteed] %9(%8, %4) : $@convention(thin) (Float, Float, Float) -> (Float, Float) |
| 167 | + %44 = partial_apply [callee_guaranteed] %11(%43) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float |
| 168 | + %45 = tuple $(predecessor: _AD__$s4test5ClassV6methodSfyF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float) (%40, %44) |
| 169 | + %46 = enum $_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0, #_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0.bb2!enumelt, %45 |
| 170 | + br bb3(%42, %46) |
| 171 | + |
| 172 | +bb3(%48 : $Float, %49 : $_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0): |
| 173 | + // function_ref pullback of Class.method() |
| 174 | + %50 = function_ref @$s4test5ClassV6methodSfyFTJpSpSr : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector |
| 175 | + %51 = partial_apply [callee_guaranteed] %50(%49) : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector |
| 176 | + // function_ref pullback of methodWrapper(_:) |
| 177 | + %52 = function_ref @$s4test13methodWrapperySfAA5ClassVFTJpSpSr : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Class.TangentVector) -> Class.TangentVector |
| 178 | + %53 = partial_apply [callee_guaranteed] %52(%51) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Class.TangentVector) -> Class.TangentVector |
| 179 | + %54 = tuple (%48, %53) |
| 180 | + return %54 |
| 181 | +} // end sil function '$s4test13methodWrapperySfAA5ClassVFTJrSpSr' |
0 commit comments