|
| 1 | +// RUN: %target-swift-frontend -enable-experimental-differentiable-programming -emit-ir %s | %FileCheck %s |
| 2 | + |
| 3 | +sil_stage raw |
| 4 | + |
| 5 | +import Swift |
| 6 | +import Builtin |
| 7 | +import _Differentiation |
| 8 | + |
| 9 | +sil @f : $@convention(thin) (Float) -> Float { |
| 10 | +bb0(%0 : $Float): |
| 11 | + return undef : $Float |
| 12 | +} |
| 13 | + |
| 14 | +sil @f_jvp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { |
| 15 | +bb0(%0 : $Float): |
| 16 | + return undef : $(Float, @callee_guaranteed (Float) -> Float) |
| 17 | +} |
| 18 | + |
| 19 | +sil @f_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { |
| 20 | +bb0(%0 : $Float): |
| 21 | + return undef : $(Float, @callee_guaranteed (Float) -> Float) |
| 22 | +} |
| 23 | + |
| 24 | +sil @test_form_diff_func : $@convention(thin) () -> @owned @differentiable @callee_guaranteed (Float) -> Float { |
| 25 | +bb0: |
| 26 | + %orig = function_ref @f : $@convention(thin) (Float) -> Float |
| 27 | + %origThick = thin_to_thick_function %orig : $@convention(thin) (Float) -> Float to $@callee_guaranteed (Float) -> Float |
| 28 | + %jvp = function_ref @f_jvp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) |
| 29 | + %jvpThick = thin_to_thick_function %jvp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) to $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) |
| 30 | + %vjp = function_ref @f_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) |
| 31 | + %vjpThick = thin_to_thick_function %vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) to $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) |
| 32 | + %result = differentiable_function [parameters 0] %origThick : $@callee_guaranteed (Float) -> Float with_derivative {%jvpThick : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), %vjpThick : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)} |
| 33 | + return %result : $@differentiable @callee_guaranteed (Float) -> Float |
| 34 | +} |
| 35 | + |
| 36 | +// CHECK-LABEL: define{{.*}}test_form_diff_func(<{ %swift.function, %swift.function, %swift.function }>* |
| 37 | +// CHECK-SAME: [[OUT:%.*]]) |
| 38 | +// CHECK: [[OUT_ORIG:%.*]] = getelementptr{{.*}}[[OUT]], i32 0, i32 0 |
| 39 | +// CHECK: [[OUT_ORIG_FN:%.*]] = getelementptr{{.*}}[[OUT_ORIG]], i32 0, i32 0 |
| 40 | +// CHECK: store{{.*}}@f{{.*}}[[OUT_ORIG_FN]] |
| 41 | +// CHECK: [[OUT_ORIG_DATA:%.*]] = getelementptr{{.*}} [[OUT_ORIG]], i32 0, i32 1 |
| 42 | +// CHECK: store{{.*}}null{{.*}} [[OUT_ORIG_DATA]] |
| 43 | +// CHECK: [[OUT_JVP:%.*]] = getelementptr{{.*}}[[OUT]], i32 0, i32 1 |
| 44 | +// CHECK: [[OUT_JVP_FN:%.*]] = getelementptr{{.*}}[[OUT_JVP]], i32 0, i32 0 |
| 45 | +// CHECK: store{{.*}}@f_jvp{{.*}}[[OUT_JVP_FN]] |
| 46 | +// CHECK: [[OUT_JVP_DATA:%.*]] = getelementptr{{.*}}[[OUT_JVP]], i32 0, i32 1 |
| 47 | +// CHECK: store{{.*}} null{{.*}}[[OUT_JVP_DATA]] |
| 48 | +// CHECK: [[OUT_VJP:%.*]] = getelementptr{{.*}}[[OUT]], i32 0, i32 2 |
| 49 | +// CHECK: [[OUT_VJP_FN:%.*]] = getelementptr{{.*}}[[OUT_VJP]], i32 0, i32 0 |
| 50 | +// CHECK: store{{.*}}@f_vjp{{.*}}[[OUT_VJP_FN]] |
| 51 | +// CHECK: [[OUT_VJP_DATA:%.*]] = getelementptr{{.*}}[[OUT_VJP]], i32 0, i32 1 |
| 52 | +// CHECK: store{{.*}}null{{.*}}[[OUT_VJP_DATA]] |
| 53 | + |
| 54 | +sil @test_extract_components : $@convention(thin) (@guaranteed @differentiable @callee_guaranteed (Float) -> Float) -> (@owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), @owned @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) { |
| 55 | +bb0(%0 : $@differentiable @callee_guaranteed (Float) -> Float): |
| 56 | + %orig = differentiable_function_extract [original] %0 : $@differentiable @callee_guaranteed (Float) -> Float |
| 57 | + %jvp = differentiable_function_extract [jvp] %0 : $@differentiable @callee_guaranteed (Float) -> Float |
| 58 | + %vjp = differentiable_function_extract [vjp] %0 : $@differentiable @callee_guaranteed (Float) -> Float |
| 59 | + %result = tuple (%orig : $@callee_guaranteed (Float) -> Float, %jvp : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), %vjp : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) |
| 60 | + return %result : $(@callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) |
| 61 | +} |
| 62 | + |
| 63 | +// CHECK-LABEL: define{{.*}}@test_extract_components(<{ %swift.function, %swift.function, %swift.function }>* |
| 64 | +// CHECK-SAME: [[OUT:%.*]], <{ %swift.function, %swift.function, %swift.function }>*{{.*}}[[IN:%.*]]) |
| 65 | +// CHECK: [[ORIG:%.*]] = getelementptr{{.*}}[[IN]], i32 0, i32 0 |
| 66 | +// CHECK: [[ORIG_FN_ADDR:%.*]] = getelementptr{{.*}}[[ORIG]], i32 0, i32 0 |
| 67 | +// CHECK: [[ORIG_FN:%.*]] = load{{.*}}[[ORIG_FN_ADDR]] |
| 68 | +// CHECK: [[ORIG_DATA_ADDR:%.*]] = getelementptr{{.*}}[[ORIG]], i32 0, i32 1 |
| 69 | +// CHECK: [[ORIG_DATA:%.*]] = load{{.*}}[[ORIG_DATA_ADDR]] |
| 70 | +// CHECK: [[JVP:%.*]] = getelementptr{{.*}}[[IN]], i32 0, i32 1 |
| 71 | +// CHECK: [[JVP_FN_ADDR:%.*]] = getelementptr{{.*}}[[JVP]], i32 0, i32 0 |
| 72 | +// CHECK: [[JVP_FN:%.*]] = load{{.*}}[[JVP_FN_ADDR]] |
| 73 | +// CHECK: [[JVP_DATA_ADDR:%.*]] = getelementptr{{.*}}[[JVP]], i32 0, i32 1 |
| 74 | +// CHECK: [[JVP_DATA:%.*]] = load{{.*}}[[JVP_DATA_ADDR]] |
| 75 | +// CHECK: [[VJP:%.*]] = getelementptr{{.*}}[[IN]], i32 0, i32 2 |
| 76 | +// CHECK: [[VJP_FN_ADDR:%.*]] = getelementptr{{.*}}[[VJP]], i32 0, i32 0 |
| 77 | +// CHECK: [[VJP_FN:%.*]] = load{{.*}}[[VJP_FN_ADDR]] |
| 78 | +// CHECK: [[VJP_DATA_ADDR:%.*]] = getelementptr{{.*}}[[VJP]], i32 0, i32 1 |
| 79 | +// CHECK: [[VJP_DATA:%.*]] = load{{.*}}[[VJP_DATA_ADDR]] |
| 80 | +// CHECK: [[OUT_0:%.*]] = getelementptr{{.*}}[[OUT]], i32 0, i32 0 |
| 81 | +// CHECK: [[OUT_0_FN:%.*]] = getelementptr{{.*}}[[OUT_0]], i32 0, i32 0 |
| 82 | +// CHECK: store{{.*}}[[ORIG_FN]]{{.*}}[[OUT_0_FN]] |
| 83 | +// CHECK: [[OUT_0_DATA:%.*]] = getelementptr{{.*}}[[OUT_0]], i32 0, i32 1 |
| 84 | +// CHECK: store{{.*}}[[ORIG_DATA]]{{.*}}[[OUT_0_DATA]] |
| 85 | +// CHECK: [[OUT_1:%.*]] = getelementptr{{.*}}[[OUT]], i32 0, i32 1 |
| 86 | +// CHECK: [[OUT_1_FN:%.*]] = getelementptr{{.*}}[[OUT_1]], i32 0, i32 0 |
| 87 | +// CHECK: store{{.*}}[[JVP_FN]]{{.*}}[[OUT_1_FN]] |
| 88 | +// CHECK: [[OUT_1_DATA:%.*]] = getelementptr{{.*}}[[OUT_1]], i32 0, i32 1 |
| 89 | +// CHECK: store{{.*}}[[JVP_DATA]]{{.*}}[[OUT_1_DATA]] |
| 90 | +// CHECK: [[OUT_2:%.*]] = getelementptr{{.*}}[[OUT]], i32 0, i32 2 |
| 91 | +// CHECK: [[OUT_2_FN:%.*]] = getelementptr{{.*}}[[OUT_2]], i32 0, i32 0 |
| 92 | +// CHECK: store{{.*}}[[VJP_FN]]{{.*}}[[OUT_2_FN]] |
| 93 | +// CHECK: [[OUT_2_DATA:%.*]] = getelementptr{{.*}}[[OUT_2]], i32 0, i32 1 |
| 94 | +// CHECK: store{{.*}}[[VJP_DATA]]{{.*}}[[OUT_2_DATA]] |
| 95 | + |
| 96 | +sil @test_call_diff_fn : $@convention(thin) (@guaranteed @differentiable @callee_guaranteed (Float) -> Float, Float) -> Float { |
| 97 | +bb0(%0 : $@differentiable @callee_guaranteed (Float) -> Float, %1 : $Float): |
| 98 | + %2 = apply %0(%1) : $@differentiable @callee_guaranteed (Float) -> Float |
| 99 | + return %2 : $Float |
| 100 | +} |
| 101 | + |
| 102 | +// CHECK-LABEL: define{{.*}}@test_call_diff_fn(<{ %swift.function, %swift.function, %swift.function }>* |
| 103 | +// CHECK-SAME: [[DIFF_FN:%.*]], float [[INPUT_FLOAT:%.*]]) |
| 104 | +// CHECK: [[ORIG:%.*]] = getelementptr{{.*}}[[DIFF_FN]], i32 0, i32 0 |
| 105 | +// CHECK: [[ORIG_FN_ADDR:%.*]] = getelementptr{{.*}}[[ORIG]], i32 0, i32 0 |
| 106 | +// CHECK: [[ORIG_FN:%.*]] = load{{.*}}[[ORIG_FN_ADDR]] |
| 107 | +// CHECK: [[ORIG_DATA_ADDR:%.*]] = getelementptr{{.*}}[[ORIG]], i32 0, i32 1 |
| 108 | +// CHECK: [[ORIG_DATA:%.*]] = load{{.*}}[[ORIG_DATA_ADDR]] |
| 109 | +// CHECK: [[ORIG_FN_CAST:%.*]] = bitcast{{.*}}[[ORIG_FN]] |
| 110 | +// CHECK: [[RESULT:%.*]] = call swiftcc float [[ORIG_FN_CAST]](float [[INPUT_FLOAT]], %swift.refcounted* swiftself [[ORIG_DATA]]) |
| 111 | +// CHECK: ret float [[RESULT]] |
| 112 | + |
| 113 | +sil @test_convert_escape_to_noescape : $@convention(thin) (@guaranteed @differentiable @callee_guaranteed (Float) -> Float) -> () { |
| 114 | +bb0(%0 : $@differentiable @callee_guaranteed (Float) -> Float): |
| 115 | + %1 = convert_escape_to_noescape %0 : $@differentiable @callee_guaranteed (Float) -> Float to $@noescape @differentiable @callee_guaranteed (Float) -> Float |
| 116 | + return undef : $() |
| 117 | +} |
0 commit comments