Skip to content

Commit d6bbf97

Browse files
committed
Add simple generated derivative code FileCheck test.
1 parent b833271 commit d6bbf97

File tree

1 file changed

+65
-0
lines changed

1 file changed

+65
-0
lines changed
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
// RUN: %target-swift-frontend -emit-sil -verify -Xllvm -sil-print-after=differentiation -o /dev/null 2>&1 %s | %FileCheck %s -check-prefix=CHECK-SIL
2+
// REQUIRES: asserts
3+
4+
// Simple generated derivative code FileCheck tests.
5+
6+
import _Differentiation
7+
8+
extension Float {
9+
@_silgen_name("add")
10+
static func add(_ x: Float, _ y: Float) -> Float {
11+
return x + y
12+
}
13+
14+
@derivative(of: add)
15+
static func addVJP(_ x: Float, _ y: Float) -> (
16+
value: Float, pullback: (Float) -> (Float, Float)
17+
) {
18+
return (add(x, y), { v in (v, v) })
19+
}
20+
}
21+
22+
@_silgen_name("foo")
23+
@differentiable
24+
func foo(_ x: Float) -> Float {
25+
let y = Float.add(x, x)
26+
return y
27+
}
28+
29+
// CHECK-SIL-LABEL: sil hidden [ossa] @AD__foo__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
30+
// CHECK-SIL: bb0([[X:%.*]] : $Float):
31+
// CHECK-SIL: [[ADD_ORIG_REF:%.*]] = function_ref @add : $@convention(method) (Float, Float, @thin Float.Type) -> Float
32+
// CHECK-SIL: [[ADD_JVP_REF:%.*]] = differentiability_witness_function [jvp] [parameters 0 1] [results 0] @add
33+
// CHECK-SIL: [[ADD_VJP_REF:%.*]] = differentiability_witness_function [vjp] [parameters 0 1] [results 0] @add
34+
// CHECK-SIL: [[ADD_DIFF_FN:%.*]] = differentiable_function [parameters 0 1] [[ADD_ORIG_REF]] : $@convention(method) (Float, Float, @thin Float.Type) -> Float with_derivative {[[ADD_JVP_REF]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float), [[ADD_VJP_REF]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float))}
35+
// CHECK-SIL: [[ADD_VJP_FN:%.*]] = differentiable_function_extract [vjp] [[ADD_DIFF_FN]]
36+
// CHECK-SIL: end_borrow [[ADD_DIFF_FN]]
37+
// CHECK-SIL: [[ADD_RESULT:%.*]] = apply [[ADD_VJP_FN]]([[X]], [[X]], {{.*}})
38+
// CHECK-SIL: ([[ORIG_RES:%.*]], [[ADD_PB:%.*]]) = destructure_tuple [[ADD_RESULT]]
39+
// CHECK-SIL: [[PB_STRUCT:%.*]] = struct $_AD__foo_bb0__PB__src_0_wrt_0 ([[ADD_PB]] : $@callee_guaranteed (Float) -> (Float, Float))
40+
// CHECK-SIL: [[PB_REF:%.*]] = function_ref @AD__foo__pullback_src_0_wrt_0 : $@convention(thin) (Float, @owned _AD__foo_bb0__PB__src_0_wrt_0) -> Float
41+
// CHECK-SIL: [[PB_FN:%.*]] = partial_apply [callee_guaranteed] [[PB_REF]]([[PB_STRUCT]])
42+
// CHECK-SIL: [[VJP_RESULT:%.*]] = tuple ([[ORIG_RES]] : $Float, [[PB_FN]] : $@callee_guaranteed (Float) -> Float)
43+
// CHECK-SIL: return [[VJP_RESULT]] : $(Float, @callee_guaranteed (Float) -> Float)
44+
// CHECK-SIL: }
45+
46+
// CHECK-SIL-LABEL: sil private [ossa] @AD__foo__pullback_src_0_wrt_0 : $@convention(thin) (Float, @owned _AD__foo_bb0__PB__src_0_wrt_0) -> Float {
47+
// CHECK-SIL: bb0([[SEED:%.*]] : $Float, [[PB_STRUCT:%.*]] : @owned $_AD__foo_bb0__PB__src_0_wrt_0):
48+
// CHECK-SIL: [[ADD_PB:%.*]] = destructure_struct [[PB_STRUCT]] : $_AD__foo_bb0__PB__src_0_wrt_0
49+
// CHECK-SIL: [[ADD_PB_RES:%.*]] = apply [[ADD_PB]]([[SEED]]) : $@callee_guaranteed (Float) -> (Float, Float)
50+
// CHECK-SIL: ([[DX_1:%.*]], [[DX_2:%.*]]) = destructure_tuple [[ADD_PB_RES]] : $(Float, Float)
51+
// CHECK-SIL: [[TMP_BUF_RES:%.*]] = alloc_stack $Float
52+
// CHECK-SIL: [[TMP_BUF_LHS:%.*]] = alloc_stack $Float
53+
// CHECK-SIL: [[TMP_BUF_RHS:%.*]] = alloc_stack $Float
54+
// CHECK-SIL: store [[DX_1]] to [trivial] [[TMP_BUF_LHS]] : $*Float
55+
// CHECK-SIL: store [[DX_2]] to [trivial] [[TMP_BUF_RHS]] : $*Float
56+
// CHECK-SIL: [[PLUS_FN:%.*]] = witness_method $Float, #AdditiveArithmetic."+"
57+
// CHECK-SIL: apply [[PLUS_FN]]<Float>([[TMP_BUF_RES]], [[TMP_BUF_RHS]], [[TMP_BUF_LHS]], {{.*}})
58+
// CHECK-SIL: destroy_addr [[TMP_BUF_LHS]] : $*Float
59+
// CHECK-SIL: destroy_addr [[TMP_BUF_RHS]] : $*Float
60+
// CHECK-SIL: dealloc_stack [[TMP_BUF_RHS]] : $*Float
61+
// CHECK-SIL: dealloc_stack [[TMP_BUF_LHS]] : $*Float
62+
// CHECK-SIL: [[DX:%.*]] = load [trivial] [[TMP_BUF_RES]] : $*Float
63+
// CHECK-SIL: dealloc_stack [[TMP_BUF_RES]] : $*Float
64+
// CHECK-SIL: return [[DX]] : $Float
65+
// CHECK-SIL: }

0 commit comments

Comments
 (0)