|
| 1 | +// RUN: %target-swift-frontend -emit-sil -verify -Xllvm -sil-print-after=differentiation -o /dev/null 2>&1 %s | %FileCheck %s -check-prefix=CHECK-SIL |
| 2 | +// REQUIRES: asserts |
| 3 | + |
| 4 | +// Simple generated derivative code FileCheck tests. |
| 5 | + |
| 6 | +import _Differentiation |
| 7 | + |
| 8 | +extension Float { |
| 9 | + @_silgen_name("add") |
| 10 | + static func add(_ x: Float, _ y: Float) -> Float { |
| 11 | + return x + y |
| 12 | + } |
| 13 | + |
| 14 | + @derivative(of: add) |
| 15 | + static func addVJP(_ x: Float, _ y: Float) -> ( |
| 16 | + value: Float, pullback: (Float) -> (Float, Float) |
| 17 | + ) { |
| 18 | + return (add(x, y), { v in (v, v) }) |
| 19 | + } |
| 20 | +} |
| 21 | + |
| 22 | +@_silgen_name("foo") |
| 23 | +@differentiable |
| 24 | +func foo(_ x: Float) -> Float { |
| 25 | + let y = Float.add(x, x) |
| 26 | + return y |
| 27 | +} |
| 28 | + |
| 29 | +// CHECK-SIL-LABEL: sil hidden [ossa] @AD__foo__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { |
| 30 | +// CHECK-SIL: bb0([[X:%.*]] : $Float): |
| 31 | +// CHECK-SIL: [[ADD_ORIG_REF:%.*]] = function_ref @add : $@convention(method) (Float, Float, @thin Float.Type) -> Float |
| 32 | +// CHECK-SIL: [[ADD_JVP_REF:%.*]] = differentiability_witness_function [jvp] [parameters 0 1] [results 0] @add |
| 33 | +// CHECK-SIL: [[ADD_VJP_REF:%.*]] = differentiability_witness_function [vjp] [parameters 0 1] [results 0] @add |
| 34 | +// CHECK-SIL: [[ADD_DIFF_FN:%.*]] = differentiable_function [parameters 0 1] [[ADD_ORIG_REF]] : $@convention(method) (Float, Float, @thin Float.Type) -> Float with_derivative {[[ADD_JVP_REF]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float), [[ADD_VJP_REF]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float))} |
| 35 | +// CHECK-SIL: [[ADD_VJP_FN:%.*]] = differentiable_function_extract [vjp] [[ADD_DIFF_FN]] |
| 36 | +// CHECK-SIL: end_borrow [[ADD_DIFF_FN]] |
| 37 | +// CHECK-SIL: [[ADD_RESULT:%.*]] = apply [[ADD_VJP_FN]]([[X]], [[X]], {{.*}}) |
| 38 | +// CHECK-SIL: ([[ORIG_RES:%.*]], [[ADD_PB:%.*]]) = destructure_tuple [[ADD_RESULT]] |
| 39 | +// CHECK-SIL: [[PB_STRUCT:%.*]] = struct $_AD__foo_bb0__PB__src_0_wrt_0 ([[ADD_PB]] : $@callee_guaranteed (Float) -> (Float, Float)) |
| 40 | +// CHECK-SIL: [[PB_REF:%.*]] = function_ref @AD__foo__pullback_src_0_wrt_0 : $@convention(thin) (Float, @owned _AD__foo_bb0__PB__src_0_wrt_0) -> Float |
| 41 | +// CHECK-SIL: [[PB_FN:%.*]] = partial_apply [callee_guaranteed] [[PB_REF]]([[PB_STRUCT]]) |
| 42 | +// CHECK-SIL: [[VJP_RESULT:%.*]] = tuple ([[ORIG_RES]] : $Float, [[PB_FN]] : $@callee_guaranteed (Float) -> Float) |
| 43 | +// CHECK-SIL: return [[VJP_RESULT]] : $(Float, @callee_guaranteed (Float) -> Float) |
| 44 | +// CHECK-SIL: } |
| 45 | + |
| 46 | +// CHECK-SIL-LABEL: sil private [ossa] @AD__foo__pullback_src_0_wrt_0 : $@convention(thin) (Float, @owned _AD__foo_bb0__PB__src_0_wrt_0) -> Float { |
| 47 | +// CHECK-SIL: bb0([[SEED:%.*]] : $Float, [[PB_STRUCT:%.*]] : @owned $_AD__foo_bb0__PB__src_0_wrt_0): |
| 48 | +// CHECK-SIL: [[ADD_PB:%.*]] = destructure_struct [[PB_STRUCT]] : $_AD__foo_bb0__PB__src_0_wrt_0 |
| 49 | +// CHECK-SIL: [[ADD_PB_RES:%.*]] = apply [[ADD_PB]]([[SEED]]) : $@callee_guaranteed (Float) -> (Float, Float) |
| 50 | +// CHECK-SIL: ([[DX_1:%.*]], [[DX_2:%.*]]) = destructure_tuple [[ADD_PB_RES]] : $(Float, Float) |
| 51 | +// CHECK-SIL: [[TMP_BUF_RES:%.*]] = alloc_stack $Float |
| 52 | +// CHECK-SIL: [[TMP_BUF_LHS:%.*]] = alloc_stack $Float |
| 53 | +// CHECK-SIL: [[TMP_BUF_RHS:%.*]] = alloc_stack $Float |
| 54 | +// CHECK-SIL: store [[DX_1]] to [trivial] [[TMP_BUF_LHS]] : $*Float |
| 55 | +// CHECK-SIL: store [[DX_2]] to [trivial] [[TMP_BUF_RHS]] : $*Float |
| 56 | +// CHECK-SIL: [[PLUS_FN:%.*]] = witness_method $Float, #AdditiveArithmetic."+" |
| 57 | +// CHECK-SIL: apply [[PLUS_FN]]<Float>([[TMP_BUF_RES]], [[TMP_BUF_RHS]], [[TMP_BUF_LHS]], {{.*}}) |
| 58 | +// CHECK-SIL: destroy_addr [[TMP_BUF_LHS]] : $*Float |
| 59 | +// CHECK-SIL: destroy_addr [[TMP_BUF_RHS]] : $*Float |
| 60 | +// CHECK-SIL: dealloc_stack [[TMP_BUF_RHS]] : $*Float |
| 61 | +// CHECK-SIL: dealloc_stack [[TMP_BUF_LHS]] : $*Float |
| 62 | +// CHECK-SIL: [[DX:%.*]] = load [trivial] [[TMP_BUF_RES]] : $*Float |
| 63 | +// CHECK-SIL: dealloc_stack [[TMP_BUF_RES]] : $*Float |
| 64 | +// CHECK-SIL: return [[DX]] : $Float |
| 65 | +// CHECK-SIL: } |
0 commit comments